import os

import h5py
import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


from extras.parse_args import args

# wandb is optional: the module is imported only when `args.wandb` is set.
# Otherwise the name stays bound to None, so every `if wandb:` check in this
# file skips the wandb calls.
wandb = None
if args.wandb:
    import wandb

from extras.wandb_utils import init_wandb, finish_wandb
from extras.gpu_stats_server import run_gpu_stats_server, fetch_gpu_stats

from pde.resnet1d import ResNet1D
from solver.ode_layer import ODESYSLayer


# When True, dataset tensors are float64, the ODE layer returns doubles and
# the model is converted with `.double()`; otherwise everything is float32.
DBL: bool = False
# Length (in time steps) of each input window and of each target window.
steps_per_example: int = 10
# Mini-batch size for the data loaders; also passed to ODESYSLayer as `bs`.
batch_size: int = 32
# Validation interval in epochs; also the LR-scheduler stepping frequency.
eval_freq: int = 5


class KDVDataset(Dataset):
    """Sliding-window view over one split of the KdV trajectories.

    The whole split is read from an HDF5 file into a single tensor on
    ``device`` (float64 if ``DBL`` else float32). Item ``idx`` selects
    trajectory ``idx // length_per_trajectory`` and start time
    ``t = idx % length_per_trajectory`` and returns the pair ``(x, y)``:

    * ``x = data[traj, t : t + S]`` -- the input window, and
    * ``y = data[traj, t + S - 1 : t + 2S - 1]`` -- the target window,

    with ``S = steps_per_example``. The two windows share one time step
    (``y[0]`` equals ``x[-1]``).

    Args:
        kind: Which split to load: ``'train'``, ``'valid'`` or ``'test'``.
        device: Device the loaded tensor is placed on.

    Raises:
        ValueError: If ``kind`` is not one of the known splits.
        KeyError: If the expected dataset path is missing from the HDF5 file.
    """

    def __init__(self, kind='train', device=torch.device('cpu')):
        # Validate the split first so a bad `kind` fails before any I/O.
        data_root = 'datasets/KDV_easy'
        if kind == 'train':
            file = os.path.join(data_root, 'KdV_train_512_easy.h5')
            h5_path = 'train/pde_140-256'
        elif kind == 'valid':
            file = os.path.join(data_root, 'KdV_valid_easy.h5')
            h5_path = 'valid/pde_140-256'
        elif kind == 'test':
            file = os.path.join(data_root, 'KdV_test_easy.h5')
            h5_path = 'test/pde_140-256'
        else:
            raise ValueError('Invalid dataset type')
        self.steps_per_example = steps_per_example

        # Close the HDF5 handle as soon as the array has been copied into memory.
        with h5py.File(file, 'r') as h5_file:
            dset = h5_file.get(h5_path)
            if dset is None:
                # `get` returns None for a missing path; fail here with a clear
                # message rather than with an opaque conversion error later.
                raise KeyError(f'{h5_path!r} not found in {file!r}')
            data = np.array(dset)  # e.g. float64, (512, 140, 256) for the train split
        self.data = torch.as_tensor(data, dtype=torch.float64 if DBL else torch.float32, device=device)
        self.n_trajectory = self.data.shape[0]
        self.n_time_steps = self.data.shape[1]
        print('data module ', kind, self.data.shape, self.n_time_steps)
        # Number of start times t whose target window still fits inside the
        # trajectory: t + 2S - 2 <= n_time_steps - 1.
        self.length_per_trajectory = self.n_time_steps - self.steps_per_example * 2 + 2

    def __len__(self):
        """Total number of windows across all trajectories."""
        return self.length_per_trajectory * self.n_trajectory

    def __getitem__(self, idx):
        """Return the ``(x, y)`` window pair for flat index ``idx``."""
        traj_idx = idx // self.length_per_trajectory
        time_idx = idx % self.length_per_trajectory

        x = self.data[traj_idx, time_idx : time_idx + self.steps_per_example]
        # The target starts on the last input step, overlapping x by one step.
        y = self.data[traj_idx, time_idx + self.steps_per_example - 1 : time_idx + self.steps_per_example * 2 - 1]
        return x, y


class KDVDataModule(pl.LightningDataModule):
    """Lightning data module bundling the train / valid / test KdV splits.

    All three splits are loaded eagerly at construction time onto ``device``.
    Every loader uses ``batch_size`` and runs in the main process.
    """

    def __init__(self, device=torch.device('cpu')):
        super().__init__()
        self.train_dataset, self.valid_dataset, self.test_dataset = (
            KDVDataset(kind=split, device=device) for split in ('train', 'valid', 'test')
        )
        self.batch_size = batch_size

    def _make_loader(self, dataset, *, shuffle, drop_last):
        # Shared DataLoader configuration for all three splits.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        return self._make_loader(self.valid_dataset, shuffle=False, drop_last=True)

    def test_dataloader(self):
        # NOTE(review): unlike train/val, a final partial batch is kept here,
        # while `Method` builds its ODE layer with a fixed `bs=batch_size`;
        # confirm the layer accepts a smaller last batch.
        return self._make_loader(self.test_dataset, shuffle=False, drop_last=False)


class Method(pl.LightningModule):
    """Lightning module mapping a window of past KdV states to the next window.

    From the input window ``x`` the network predicts every input of an
    ``ODESYSLayer``: per-step coefficients (``cf_cnn``), a right-hand side
    (``rhs_cnn``), initial values (last observed state plus ``iv_cnn``) and a
    step size (``step_cnn``). The layer's solution is returned as the
    predicted next ``steps_per_example`` states. The precise equation being
    solved is defined in ``solver.ode_layer`` (not visible here).
    """

    def __init__(self):
        super().__init__()
        step_size_t = .05  # initial ODE step size output by `step_cnn` (see bias init below)
        order = 4  # ODE order handed to ODESYSLayer

        # Independent dimension of the ODE layer; forward() feeds it the
        # 256-point spatial axis of the input.
        n_ind_dim = 256
        self.n_step = steps_per_example

        self.ode = ODESYSLayer(
            bs=batch_size,
            n_ind_dim=n_ind_dim,
            order=order,
            n_equations=1,
            gamma=.05,
            alpha=0.,
            n_dim=1,
            n_iv=order,
            n_step=self.n_step,
            n_iv_steps=1,
            solver_dbl=True,
            double_ret=DBL,
        )

        # Circular padding throughout; presumably the spatial domain is periodic -- confirm.
        pm = 'circular'

        # All heads take the input window's time steps as channels, so their
        # input channel count is `self.n_step` (== steps_per_example).

        # Coefficient head: n_step * (order + 1) values per spatial point.
        self.cf_cnn = nn.Sequential(
            ResNet1D(
                in_channels=self.n_step,
                base_filters=64,
                kernel_size=9,
                stride=1,
                groups=1,
                n_block=10,
                n_classes=2,
                use_bn=False,
                use_do=False,
            ),
            nn.Conv1d(256, self.n_step * (order + 1), kernel_size=7, padding=3, stride=1, padding_mode=pm),
        )

        # Right-hand-side head: n_step values per spatial point.
        self.rhs_cnn = nn.Sequential(
            ResNet1D(
                in_channels=self.n_step,
                base_filters=64,
                kernel_size=9,
                stride=1, groups=1,
                n_block=10,
                n_classes=2,
                use_bn=False,
                use_do=False,
            ),
            nn.Conv1d(256, self.n_step, kernel_size=7, padding=3, stride=1, padding_mode=pm),
        )

        # Initial-value head: `order - 1` learned values per spatial point,
        # filling the initial-value slots after the last observed state.
        # The two stride-2 convs shrink the 256-point axis to 64, hence 32 * 64.
        self.iv_cnn = nn.Sequential(
            nn.Conv1d(self.n_step, 64, kernel_size=7, padding=3, stride=1, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=7, padding=3, stride=1, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=7, padding=3, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel_size=7, padding=3, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(256, 128, kernel_size=7, padding=3, stride=2, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(128, 64, kernel_size=7, padding=3, stride=2, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(64, 32, kernel_size=7, padding=3, stride=1, padding_mode=pm),
            nn.Flatten(),
            nn.Linear(32 * 64, n_ind_dim * (order - 1)),
            nn.Unflatten(dim=-1, unflattened_size=[n_ind_dim, order - 1]),
        )

        # Step-size head: one step per sample in (0, 1) via the final sigmoid.
        # Zero weights plus bias logit(step_size_t) make the initial output
        # ~= step_size_t for any input.
        steps_layer_t = nn.Linear(32 * 64, 1)
        steps_layer_t.weight.data.fill_(0.)
        steps_layer_t.bias.data.fill_(torch.tensor(step_size_t).logit())
        self.step_cnn = nn.Sequential(
            nn.Conv1d(self.n_step, 64, kernel_size=5, padding=2, stride=1, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=5, padding=2, stride=2, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=5, padding=2, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel_size=5, padding=2, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(256, 128, kernel_size=5, padding=2, stride=1, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(128, 64, kernel_size=5, padding=2, stride=2, padding_mode=pm),
            nn.ReLU(),
            nn.Conv1d(64, 32, kernel_size=5, padding=2, stride=1, padding_mode=pm),
            nn.Flatten(),
            steps_layer_t,
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Predict the next window of states from the window ``x``.

        Args:
            x: Input window ``(batch, steps_per_example, n_x)``, e.g. ``(32, 10, 256)``.

        Returns:
            The ODE solution transposed back to ``(batch, n_step, n_x)``. Its
            first initial value is the last state of ``x``.
        """
        # Shape comments are for batch 32, 10 steps, 256 spatial points.
        coeffs = self.cf_cnn(x).permute(0, 2, 1)  # (32, 256, 50)
        rhs = self.rhs_cnn(x).permute(0, 2, 1)  # (32, 256, 10)
        # One step per sample, clipped away from zero and repeated for every
        # spatial point and each of the n_step - 1 intervals.
        steps_t = self.step_cnn(x).clip(min=1e-3)[:, None, :].repeat(1, coeffs.size(-2), self.n_step - 1)  # (32, 256, 9)

        iv_t = torch.cat([
            x[:, -1, :, None],  # (32, 256, 1): last observed state
            self.iv_cnn(x),    # (32, 256, 3): learned remaining initial values
        ], dim=-1)  # (32, 256, 4)

        u, *_ = self.ode(coeffs, rhs, iv_t, steps_t)  # (32, 256, 10, 1)
        u = u[..., 0].transpose(-2, -1)  # (32, 10, 256)
        return u

    def loss(self, batch):
        """L1 error summed over the last (spatial) axis, averaged over batch and time."""
        x, y = batch[0], batch[1]
        u = self.forward(x)
        loss = (u - y).abs().sum(dim=-1).mean()
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.loss(batch)
        self.log('train_loss', loss, prog_bar=True, logger=True, on_epoch=True, on_step=True)

        if batch_idx % 10 == 0:
            gpu_utilization = gpu_memory_used = -1.
            gpu_stats = fetch_gpu_stats(args.gpu_stats_port)
            device_index = loss.device.index
            # `index` is None on CPU; indexing the stats with it would raise,
            # so keep the -1 sentinel in that case.
            if gpu_stats is not None and device_index is not None:
                gpu_utilization = gpu_stats[device_index]['utilization']
                gpu_memory_used = gpu_stats[device_index]['memory_used']

            if wandb:
                # NOTE(review): 'step' is batch_idx, which restarts at 0 every
                # epoch; self.global_step may be what the dashboard expects.
                wandb.log({'step': batch_idx, 'loss': loss.item(), 'gpu_utilization': gpu_utilization, 'gpu_memory_used': gpu_memory_used})

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self.loss(batch)
        self.log('val_loss', loss, prog_bar=True, logger=True, on_epoch=True)
        return {"loss": loss}

    def test_step(self, batch, batch_idx):
        loss = self.loss(batch)
        self.log('test_loss', loss, prog_bar=True, logger=True, on_epoch=True, on_step=True)
        return {"loss": loss}

    def configure_optimizers(self):
        """Adam (lr 1e-4) with ReduceLROnPlateau on ``val_loss``.

        The scheduler is stepped once every ``eval_freq`` epochs and halves the
        LR after 2 non-improving steps. ``strict`` requires ``val_loss`` to be
        available when it steps.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        lr_scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=.5, patience=2),
            "interval": "epoch",
            "frequency": eval_freq,
            "monitor": "val_loss",
            "strict": True,
            "name": None,
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}


# Shared module-level data module used by train() and evaluate(). It is built
# at import time, so importing this file reads all three splits from disk.
datamodule = KDVDataModule(torch.device('cpu'))


def train(ckpt: str = ''):
    """Train ``Method`` on the module-level KdV ``datamodule``.

    Starts the GPU-stats server and optionally initialises a wandb run. Then
    fits for up to 800 epochs, validating every ``eval_freq`` epochs and
    checkpointing on ``val_loss`` under ``args.log_dir``. The wandb run is
    finished and the stats server killed even if setup or training raises.

    Args:
        ckpt: Optional checkpoint path. When non-empty, only the module weights
            are loaded via ``load_from_checkpoint``. ``trainer.fit`` is not
            given ``ckpt_path``, so optimizer, scheduler and epoch state start
            fresh.

    Raises:
        RuntimeError: If wandb is enabled and ``init_wandb`` reports failure.
    """
    print(args)
    p = run_gpu_stats_server(args.gpu_stats_port)
    try:
        if wandb:
            wandb_login_success = init_wandb(args=args, experiment_name=args.log_dir)
            # Explicit raise: unlike `assert`, it is not stripped under `python -O`.
            if not wandb_login_success:
                raise RuntimeError('wandb initialisation failed')
        try:
            trainer = pl.Trainer(
                max_epochs=800,
                accelerator='gpu' if torch.cuda.is_available() else 'cpu',
                check_val_every_n_epoch=eval_freq,
                devices=1,
                callbacks=[
                    pl.callbacks.ModelCheckpoint(filename='{epoch:04}-{step:06}-{val_loss:.6f}', monitor='val_loss', save_last=True, mode='min'),
                ],
                log_every_n_steps=1,
                default_root_dir=args.log_dir
            )

            method = Method.load_from_checkpoint(ckpt) if ckpt else Method()
            # Applied to both branches: load_from_checkpoint builds a fresh
            # Method() (float32 parameters) before loading weights into it, so it
            # needs the same dtype conversion as a new model to match DBL data.
            if DBL:
                method.double()
            trainer.fit(method, datamodule=datamodule)
        finally:
            if wandb:
                finish_wandb()
    finally:
        p.kill()


def evaluate(ckpt: str, save_name: str):
    """Roll a trained checkpoint out autoregressively over every test trajectory.

    Writes two tensors under ``logs/kdv_eval``, creating the directory if it is
    missing:

    * ``kdv_gt.pth`` holds the ground-truth test data.
    * ``{save_name}.pth`` holds the rollout, shaped like the test data.

    Each rollout starts from the first ``steps_per_example`` ground-truth
    states. Consecutive windows share one time step: the model's first output
    step sits at the time of its last input step. So only the first
    ``steps_per_example - 1`` states of each window are kept, and the first
    kept chunk is ground truth.

    Args:
        ckpt: Path of the checkpoint to load.
        save_name: File stem for the saved rollout.
    """
    out_dir = 'logs/kdv_eval'
    os.makedirs(out_dir, exist_ok=True)

    gt = datamodule.test_dataset.data
    torch.save(gt.cpu(), os.path.join(out_dir, 'kdv_gt.pth'))

    method = Method.load_from_checkpoint(ckpt)
    if DBL:
        method.double()  # match the float64 data, as in train()
    method.eval()

    # Run on whichever device the checkpoint was restored to, instead of assuming CUDA.
    true_trajectory = gt.to(method.device)
    n_traj, n_time = true_trajectory.shape[:2]
    keep = steps_per_example - 1
    n_windows = -(-n_time // keep)  # ceil(n_time / keep): 16 windows for 140 steps

    trajectories = []
    with torch.no_grad():
        # NOTE(review): ODESYSLayer is built with a fixed bs=batch_size; a final
        # partial batch (n_traj not a multiple of batch_size) may not be supported.
        for start in tqdm(range(0, n_traj, batch_size)):
            window = true_trajectory[start:start + batch_size, :steps_per_example]
            chunks = []
            for w in range(n_windows):
                if w:
                    # Only predict windows that are actually kept; the original
                    # loop also ran one final, discarded forward pass.
                    window = method(window)
                chunks.append(window[:, :-1])
            trajectories.append(torch.cat(chunks, dim=1)[:, :n_time])
    trajectories = torch.cat(trajectories, dim=0)
    torch.save(trajectories.cpu(), os.path.join(out_dir, f'{save_name}.pth'))


def get_losses(save_names: list[str], length: int = 100):
    """Score saved rollouts against the saved ground truth.

    Loads ``logs/kdv_eval/kdv_gt.pth`` and, for each name,
    ``logs/kdv_eval/{name}.pth``. Both are compared over time steps
    ``[steps_per_example, steps_per_example + length)``, skipping the leading
    steps where rollouts from ``evaluate`` are (mostly) ground truth.

    For each (trajectory, time step) the error is
    ``mean_x (pred - gt)^2 / mean_x gt^2`` over the last axis. The score is the
    mean of that quantity.

    Args:
        save_names: File stems of rollouts to score.
        length: Number of time steps to compare.

    Returns:
        Dict mapping each save name to its scalar relative squared error.
    """
    eval_dir = 'logs/kdv_eval'
    gt = torch.load(os.path.join(eval_dir, 'kdv_gt.pth'), weights_only=True)
    # The ground-truth window and its energy are identical for every prediction,
    # so compute them once outside the loop.
    true_x = gt[:, steps_per_example: steps_per_example + length]
    true_energy = true_x.pow(2).mean(dim=-1)

    result = {}
    for save_name in save_names:
        pred = torch.load(os.path.join(eval_dir, f'{save_name}.pth'), weights_only=True)
        prediction = pred[:, steps_per_example: steps_per_example + length]
        loss = (prediction - true_x).pow(2).mean(dim=-1) / true_energy
        result[save_name] = loss.mean().item()
    return result


if __name__ == '__main__':
    # Script entry point: train from scratch (empty checkpoint path).
    train(ckpt='')
