import itertools
import os

from matplotlib import pyplot as plt
import numpy as np
import scipy
import torch
from torch import nn
import torch.utils.data
from tqdm import tqdm


from extras.parse_args import args

if args.wandb:
    import wandb
else:
    wandb = None

import extras.logger
import extras.source
from extras.wandb_utils import init_wandb, finish_wandb
from extras.gpu_stats_server import run_gpu_stats_server, fetch_gpu_stats

from solver.ode_forward import ode_forward
from solver.ode_layer import ODEINDLayer


# Switches the whole script into its fine-tuning configuration. When True, the
# integration step size becomes 1e-3 instead of 1e-2 (in both Model and main),
# main() loads a previously saved checkpoint before training, and the learning
# rate and the number of optimize/threshold iterations change.
is_finetune: bool = False


def get_basis_vars(n_vars: int = 3, polynomial_order: int = 2) -> tuple:
    """Enumerate the monomials of a polynomial library as tuples of variable indices.

    Each monomial is a non-decreasing tuple of indices into ``range(n_vars)``;
    the empty tuple stands for the constant term and a repeated index for a
    power (``(0, 0)`` is ``x0 ** 2``). Monomials are ordered by degree, from 0
    up to and including ``polynomial_order``.

    Args:
        n_vars: Number of state variables.
        polynomial_order: Highest monomial degree to include.

    Returns:
        A tuple of index tuples (10 entries for the defaults).
    """
    variable_ids = range(n_vars)
    monomials = []
    for degree in range(polynomial_order + 1):
        monomials.extend(itertools.combinations_with_replacement(variable_ids, degree))
    return tuple(monomials)


def compute_basis(x: torch.Tensor, basis_vars: tuple) -> torch.Tensor:
    """Evaluate every monomial in ``basis_vars`` on the state tensor ``x``.

    Args:
        x: Tensor whose last axis indexes the state variables.
        basis_vars: Tuple of index tuples; each one selects the variables
            that are multiplied together to form a monomial. An empty tuple
            selects nothing, so its product is the constant 1.

    Returns:
        Tensor of shape ``(..., len(basis_vars))`` with one monomial per entry
        of the last axis.
    """
    monomials = []
    for var_ids in basis_vars:
        # Gather the factors of this monomial along the last axis, then multiply them.
        factors = x[..., var_ids]
        monomials.append(factors.prod(dim=-1))
    return torch.stack(monomials, dim=-1)


def generate_lorenz(step_size: float, n_steps: int) -> np.ndarray:
    """Integrate the Lorenz system from (1, 1, 1) and sample it on a uniform time grid.

    The system uses the parameters rho = 28, sigma = 10, beta = 8/3 and is
    sampled at ``t = 0, step_size, 2 * step_size, ...``.

    Args:
        step_size: Time between two consecutive samples.
        n_steps: Number of samples to return (including the initial state).

    Returns:
        Array of shape ``(n_steps, 3)`` holding the (x, y, z) trajectory.
    """
    # Import the submodule explicitly: a bare ``import scipy`` only exposes
    # ``scipy.integrate`` on SciPy versions that lazily load submodules.
    from scipy.integrate import odeint

    rho, sigma, beta = 28., 10., 8. / 3.

    def f(state, t):
        x, y, z = state
        return sigma * (y - x), x * (rho - z) - y, x * y - beta * z

    state_0 = np.array([1., 1., 1.])
    # Consecutive samples must be exactly ``step_size`` apart.
    # ``np.linspace(0, step_size * n_steps, n_steps)`` includes the endpoint and
    # therefore stretches the spacing to ``step_size * n_steps / (n_steps - 1)``.
    time_steps = step_size * np.arange(n_steps)
    return odeint(f, state_0, time_steps)


def plot_lorenz(data, save_path):
    """Plot the x-z projection of a trajectory and save the figure.

    Args:
        data: 2-D array-like indexed as ``data[:, 0]`` (horizontal axis) and
            ``data[:, 2]`` (vertical axis); the first row is drawn as the
            initial point.
        save_path: Destination file passed to ``savefig``.
    """
    fig = plt.figure(figsize=(6, 4))
    ax = fig.gca()

    xs, zs = data[:, 0], data[:, 2]
    ax.plot(xs, zs, label='Trajectory')
    ax.plot(data[0, 0], data[0, 2], 'ko', label='Initialization')

    # Fixed window so that plots from different runs are directly comparable.
    ax.set_xlim(-20., 20.)
    ax.set_ylim(0., 50.)
    ax.tick_params(axis='both', which='both', length=0.)
    ax.grid()
    ax.legend(loc='upper center', framealpha=1.)

    # Opaque white axes on a fully transparent figure background.
    ax.set_facecolor((1., 1., 1., 1.))
    fig.set_facecolor((1., 1., 1., 0.))
    fig.tight_layout()

    print(f'Saving plot {save_path}')
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)


class LorenzDataset(torch.utils.data.Dataset):
    """Sliding windows over a single simulated Lorenz trajectory.

    Item ``i`` is the window of ``n_steps_per_batch`` consecutive states that
    starts at sample ``i``. The trajectory is generated with
    ``n_steps + n_steps_per_batch`` samples so that each of the ``n_steps``
    start positions yields a full-length window.
    """

    def __init__(self, step_size: float, n_steps_per_batch: int = 100, n_steps: int = 1000, device: torch.device = torch.device('cpu')):
        """Simulate the trajectory and keep it as a tensor on ``device``.

        Args:
            step_size: Sampling step forwarded to ``generate_lorenz``.
            n_steps_per_batch: Length of every window returned by indexing.
            n_steps: Number of windows, i.e. the dataset length.
            device: Device on which the trajectory tensor is stored.
        """
        self.n_steps = n_steps
        self.n_steps_per_batch = n_steps_per_batch
        total_steps = n_steps + n_steps_per_batch
        trajectory = generate_lorenz(step_size, total_steps)
        # tensor (n_steps + n_steps_per_batch, 3)
        self.x_train: torch.Tensor = torch.as_tensor(trajectory, device=device)

    def __len__(self):
        """Return the number of window start positions."""
        return self.n_steps

    def __getitem__(self, i: int):
        """Return the window of states starting at sample ``i``."""
        window_end = i + self.n_steps_per_batch
        return self.x_train[i:window_end]


class Model(nn.Module):
    """Learn a masked polynomial right-hand side for a 3-variable ODE.

    The module combines three pieces:

    * ``param_net`` maps the learned vector ``param_in`` to a coefficient
      matrix of shape ``(1, n_basis, 3)``; ``get_xi`` multiplies it by the
      ``mask`` buffer so that pruned coefficients stay at zero.
    * ``net`` maps a window of states ``(bs, n_step, 3)`` to a tensor of the
      same shape (``var``), on which the polynomial basis is evaluated.
    * An ODE solver (``ode_forward`` or ``ODEINDLayer``, chosen by
      ``args.solver``) turns the resulting right-hand side into a trajectory.

    ``step_size`` and ``coeffs`` are plain tensors created directly on
    ``device``; they are neither parameters nor buffers.
    """

    def __init__(self, n_basis, n_step_per_batch, device, batch_size):
        """Build the networks and the solver layer.

        Args:
            n_basis: Number of polynomial basis terms (rows of the coefficient matrix).
            n_step_per_batch: Number of time steps in every input window.
            device: Device on which the constant tensors and the mask are created.
            batch_size: Batch size handed to ``ODEINDLayer``.
        """
        super().__init__()
        dtype = torch.float64
        n_ind_dim = 3
        order = 2

        # Kept on the instance so that forward() derives its tensor sizes from
        # the constructor arguments instead of duplicated literals.
        self.n_ind_dim = n_ind_dim
        self.order = order
        self.n_step_per_batch = n_step_per_batch

        self.step_size = torch.full((1, 1, 1), 1e-3 if is_finetune else 1e-2, dtype=dtype, device=device)
        self.coeffs = torch.zeros(1, n_ind_dim, 1, 1, 1, order+1, dtype=dtype, device=device)  # use n_ind_dim instead of 1 to avoid repeating later (slow)
        # Only entry 1 of the last axis is non-zero.
        # NOTE(review): presumably this selects the first-derivative term, i.e. u' = rhs; confirm against the solver.
        self.coeffs[..., 1] = 1.

        mask = torch.ones(1, n_basis, n_ind_dim, dtype=dtype, device=device)  # (1, n_basis, 3)
        self.register_buffer('mask', mask)

        self.param_in = nn.Parameter(torch.randn(1, 64))
        self.param_net = nn.Sequential(
            nn.Linear(64, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, n_basis * n_ind_dim),
            nn.Unflatten(dim=-1, unflattened_size=(n_basis, n_ind_dim)),
        )

        self.net = nn.Sequential(
            nn.Flatten(start_dim=-2, end_dim=-1),
            nn.Linear(n_step_per_batch * n_ind_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, n_step_per_batch * n_ind_dim),
            nn.Unflatten(dim=-1, unflattened_size=(n_step_per_batch, n_ind_dim))
        )

        self.ode_ind_layer = ODEINDLayer(
            bs=batch_size,
            order=order,
            n_ind_dim=n_ind_dim,
            n_iv=1,
            n_step=n_step_per_batch,
            n_iv_steps=1,
            solver_dbl=True,
            gamma=.05,
            alpha=0.,
            double_ret=True,
            # device=device,
        )

    def update_mask(self, mask):
        """Multiply the stored mask in place by ``mask``; zeroed entries never come back."""
        self.mask *= mask

    def get_xi(self):
        """Return the masked coefficient matrix of shape ``(1, n_basis, 3)``."""
        return self.param_net(self.param_in) * self.mask  # (1, n_basis, 3)

    def forward(self, batch_in):
        """Solve the learned ODE on a batch of state windows.

        Args:
            batch_in: Tensor of shape ``(bs, n_step, 3)``.

        Returns:
            A pair ``(x0, var)``: the solver's trajectory and the output of
            ``net``. The training loop compares both against ``batch_in``.
        """
        n_steps_per_batch = self.n_step_per_batch
        n_ind_dim = self.n_ind_dim

        var = self.net(batch_in)  # (bs, n_step, 3)
        # Polynomial features (bs, n_step, n_basis) times coefficients (1, n_basis, 3).
        rhs = compute_basis(var, get_basis_vars()) @ self.get_xi()  # (bs, n_step, 3)

        if args.solver == 'LEAST_SQUARES':
            u = ode_forward(
                self.coeffs[None, :, :, 0],  # (1, 1, 3, 1, 1, 3)
                rhs.transpose(0, 1)[..., None],  # (n_step, bs, 3, 1)
                var[None, :, 0, :, None, None],  # (1, bs, 3, 1, 1): first state of each window
                self.step_size,
                n_steps=batch_in.size(-2),
                is_step_dim_first=True,
                enable_central_smoothness=args.central_diff,
                enable_freeze_lhs=True,
            )  # presumably (n_step, bs, 3, 1, 3); confirm against ode_forward
            x0 = u[..., 0, 0].transpose(0, 1)
        else:
            x0, x1, x2, eps, steps = self.ode_ind_layer(
                self.coeffs.expand(rhs.size(0), n_ind_dim, 1, n_steps_per_batch, 1, self.order + 1).contiguous(),
                rhs.transpose(-2, -1)[..., None, :],  # (bs, 3, 1, n_step)
                var[:, 0, :, None, None, None],  # (bs, 3, 1, 1, 1): first state of each window
                self.step_size.expand(rhs.size(0), n_ind_dim, n_steps_per_batch - 1).contiguous(),
            )
            x0 = x0.transpose(-2, -1)

        return x0, var


def print_eq(model, basis_vars, logger, stdout=False):
    """Write the currently learned equations, one line per state derivative.

    Args:
        model: Object whose ``get_xi()`` returns the coefficient tensor; after
            squeezing it must be 2-D, ``(n_basis, dim)``.
        basis_vars: Index tuples describing the basis terms. The first entry
            is always rendered as the constant ``1``.
        logger: Receives every equation through ``logger.info``.
        stdout: When true, each equation is also printed.
    """
    coefficients = model.get_xi().detach().squeeze()
    n_basis, dim = coefficients.shape

    # Human-readable name of every basis term, e.g. (0, 1) -> 'x0*x1'.
    term_names = ['1']
    for basis_var in basis_vars[1:]:
        term_names.append('*'.join(f'x{v}' for v in basis_var))

    for i in range(dim):
        terms = ''.join(
            f' + {coefficients[j, i].item()} * {term_names[j]}'
            for j in range(n_basis)
        )
        equation = f'dx{i} = 0 {terms}'
        logger.info(equation)
        if stdout:
            print(equation)


def simulate(n_steps: int, step: float, model, basis_vars) -> np.ndarray:
    """Integrate the equation currently encoded by ``model`` from the state (1, 1, 1).

    Args:
        n_steps: Number of samples to return (including the initial state).
        step: Time between two consecutive samples.
        model: Object whose ``get_xi()`` returns a tensor of shape
            ``(1, n_basis, 3)`` with the polynomial coefficients.
        basis_vars: Index tuples describing the basis terms, in the same order
            as the rows of the coefficient matrix.

    Returns:
        Array of shape ``(n_steps, 3)`` sampled at ``t = 0, step, 2 * step, ...``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` only exposes
    # ``scipy.integrate`` on SciPy versions that lazily load submodules.
    from scipy.integrate import odeint

    xi = model.get_xi()[0].detach().t().cpu().numpy()  # (n_basis, 3) -> (3, n_basis)

    def f(state, t):
        # Evaluate every monomial on the current state; an empty index tuple
        # selects nothing, so its product is the constant 1.
        basis = np.array([np.prod(state[list(basis_var)]) for basis_var in basis_vars])
        return xi @ basis

    state_0 = np.array([1., 1., 1.])
    # Consecutive samples must be exactly ``step`` apart.
    # ``np.linspace(0, step * n_steps, n_steps)`` includes the endpoint and
    # therefore stretches the spacing to ``step * n_steps / (n_steps - 1)``.
    time_steps = step * np.arange(n_steps)
    return odeint(f, state_0, time_steps)


def _run_experiment():
    """Train the model with alternating optimization and thresholding, saving artifacts per iteration.

    Iteration 0 performs no training: it only saves, prints and simulates the
    freshly initialized (or loaded) model. Every later iteration trains for a
    fixed number of epochs; from the second training iteration on, coefficients
    whose magnitude does not exceed the threshold are masked out first.
    """
    if wandb:
        wandb_login_success = init_wandb(args=args, experiment_name=args.log_dir)
        assert wandb_login_success

    log_dir, run_id = extras.source.create_log_dir(root=args.log_dir)
    logger = extras.logger.setup(log_dir, stdout=False)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    step_size = 1e-3 if is_finetune else 1e-2
    n_steps = 10000
    n_steps_per_batch = 5
    batch_size = 512
    # Coefficients whose absolute value is not above the threshold are masked
    # to 0 before every optimizer iteration after the first one.
    threshold = .1

    ds = LorenzDataset(step_size=step_size, n_steps=n_steps, n_steps_per_batch=n_steps_per_batch)
    train_loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)

    # plot train data
    plot_lorenz(ds.x_train.cpu().numpy(), os.path.join(log_dir, 'train.pdf'))
    np.save(os.path.join(log_dir, 'train.npy'), ds.x_train.cpu().numpy())

    basis_vars = get_basis_vars()
    model = Model(n_basis=len(basis_vars), n_step_per_batch=n_steps_per_batch, device=device, batch_size=batch_size)
    model.double().to(device)

    if is_finetune:
        state_dict = torch.load('logs/lorenz/lorenz_step5/lorenz_ls/1/model_10.ckpt', map_location=device, weights_only=True)
        model.load_state_dict(state_dict)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4 if is_finetune else 1e-5)

    # Optimize and threshold cycle
    max_iter = 100 if is_finetune else 10
    for step in range(max_iter + 1):
        if step > 0:
            print(f'Optimizer iteration {step}/{max_iter}')

            # threshold
            if step > 1:
                xi = model.get_xi()
                mask = (xi.abs() > threshold).float()
                logger.info(xi)
                logger.info(model.mask)
                logger.info(model.mask * mask)
                model.update_mask(mask)  # set mask

            n_epochs = 400
            with tqdm(total=n_epochs) as pbar:
                for epoch in range(n_epochs):
                    pbar.update(1)
                    for i, batch_in in enumerate(train_loader):
                        batch_in = batch_in.to(device)
                        x0, var = model(batch_in)

                        # Both the solver output and the network output must reproduce the data.
                        x_loss = (x0 - batch_in).pow(2).mean()
                        loss = x_loss + (var - batch_in).pow(2).mean()

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    gpu_stats = fetch_gpu_stats(args.gpu_stats_port)
                    if gpu_stats is not None:
                        # NOTE(review): device.index is None on a CPU device; confirm fetch_gpu_stats returns None there.
                        gpu_utilization = gpu_stats[device.index]['utilization']
                        gpu_memory_used = gpu_stats[device.index]['memory_used']
                    else:
                        gpu_utilization = gpu_memory_used = -1.

                    # The losses reported per epoch are those of the last batch of that epoch.
                    stats = {'run': run_id, 'iter': step, 'epoch': epoch, 'loss': loss.item(), 'xloss': x_loss.item(), 'gpu_utilization': gpu_utilization, 'gpu_memory_used': gpu_memory_used}
                    logger.info(' '.join([f'{k} {v}' for k, v in stats.items()]))
                    pbar.set_description(' '.join([f'{k} {v}' for k, v in stats.items()]))
                    if wandb:
                        wandb.log(stats, step=(step - 1) * n_epochs + epoch)

        # simulate and plot
        torch.save(model.state_dict(), os.path.join(log_dir, f'model_{step}.ckpt'))
        print_eq(model, basis_vars, logger, stdout=True)
        x_sim = simulate(n_steps, step_size, model, basis_vars)  # (10000, 3)
        plot_lorenz(x_sim, os.path.join(log_dir, f'sim_{step}.pdf'))
        np.save(os.path.join(log_dir, f'sim_{step}.npy'), x_sim)

    if wandb:
        finish_wandb()


def main():
    """Start the GPU stats server, run the experiment, and always stop the server afterwards."""
    print(args)
    p = run_gpu_stats_server(args.gpu_stats_port)
    try:
        _run_experiment()
    finally:
        # Stop the stats server even when the experiment raises, so that it is
        # not left running after a failed run.
        p.kill()


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
