from typing import List, Tuple, Optional, Union
import numpy as np

from torch import nn

from .base import RUNNERS, BaseRunner

from typing import Optional, List
import datetime
import copy
import time
import json
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from utilsd import use_cuda
from utilsd.earlystop import EarlyStopStatus

from ..common.function import printt
from ..common.utils import AverageMeter, GlobalTracker, to_torch

from SeqSNN.module.pelif import PELIFNode

@RUNNERS.register_module("ts", inherit=True)
class TS(BaseRunner):
    def __init__(
        self,
        task: str,
        out_ranges: Optional[List[Union[Tuple[int, int], Tuple[int, int, int]]]] = None,
        out_size: Optional[int] = None,
        aggregate: bool = True,
        **kwargs,
    ):
        """
        Runner for general time-series prediction.

        The four explicit arguments are recorded in ``self.hyper_paras``;
        every remaining keyword argument is forwarded untouched to the
        parent constructor.

        Args:
            task: the prediction task, classification or regression.
            out_ranges: a list of final ranges to take as final output.
                Should have form [(start, end), (start, end, step), ...]
            out_size: the output size for multi-class classification or
                multi-variant regression task.
            aggregate: whether to aggregate across whole sequence.
            **kwargs: options handed to the parent runner. They are not
                interpreted in this block; the descriptions below are the
                ones historically documented here (confirm against the
                parent class):
                optimizer: which optimizer to use.
                lr: learning rate.
                weight_decay: L2 regularization weight.
                loss_fn: loss function.
                metrics: metrics to evaluate model.
                observe: metric for model selection (earlystop).
                lower_is_better: whether a lower observed metric means
                    better result.
                max_epoches: maximum epoch to learn.
                batch_size: batch size.
                early_stop: earlystop rounds.
                model_path: the path to existing model parameters for
                    continued training or finetuning.
        """
        # Stored before the parent constructor runs; presumably the parent
        # reads these when building the network -- TODO confirm.
        self.hyper_paras = dict(
            task=task,
            out_ranges=out_ranges,
            out_size=out_size,
            aggregate=aggregate,
        )
        super().__init__(**kwargs)

    def _build_network(
        self,
        network,
        task: str,
        out_ranges: Optional[List[Union[Tuple[int, int, int], Tuple[int, int]]]] = None,
        out_size: Optional[int] = None,
        aggregate: bool = True,
    ) -> None:
        """Initilize the network parameters

        Args:
            task: the prediction task, classification or regression.
            out_ranges: a list of final ranges to take as final output. Should have form [(start, end), (start, end, step), ...]
            out_size: the output size for multi-class classification or multi-variant regression task.
            aggregate: whether to aggregate across whole sequence.
        """

        self.network = network
        self.aggregate = aggregate

        # Output
        if task == "classification":
            self.act_out = nn.Sigmoid()
            out_size = 1
        elif task == "multiclassification":
            self.act_out = nn.LogSoftmax(-1)
        elif task == "regression":
            self.act_out = nn.Identity()
        else:
            raise ValueError(
                ("Task must be 'classification', 'multiclassification', 'regression'")
            )

        if out_ranges is not None:
            self.out_ranges = []
            for ran in out_ranges:
                if len(ran) == 2:
                    self.out_ranges.append(np.arange(ran[0], ran[1]))
                elif len(ran) == 3:
                    self.out_ranges.append(np.arange(ran[0], ran[1], ran[2]))
                else:
                    raise ValueError(f"Unknown range {ran}")
            self.out_ranges = np.concatenate(self.out_ranges)
        else:
            self.out_ranges = None
        if out_size is not None:
            self.fc_out = nn.Linear(network.output_size, out_size)
        else:
            self.fc_out = nn.Identity()

    def forward(self, inputs):
        seq_out, emb_outs = self.network(inputs)
        if self.aggregate:
            out = emb_outs
        else:
            out = seq_out
        preds = self.fc_out(out)
        preds = self.act_out(preds.squeeze(-1))
        return preds
    
    def fit(
        self,
        trainset: Dataset,
        validset: Optional[Dataset] = None,
        testset: Optional[Dataset] = None,
        zo_loss: bool = True,
        zo_lambda: float = 1e-4,
    ) -> nn.Module:
        """Fit the model to data, if evaluation dataset is offered,
           model selection (early stopping) would be conducted on it.

        When ``validset`` is ``None`` the observed metric is taken from the
        training metrics instead.

        Args:
            trainset (Dataset): The training dataset.
            validset (Dataset, optional): The evaluation dataset. Defaults to None.
            testset (Dataset, optional): The test dataset. Defaults to None.
            zo_loss (bool): Whether to add the auxiliary term accumulated in
                ``zo_loss_value`` of every ``PELIFNode`` to the training
                loss. Skipped when the model contains no such node.
                Defaults to True.
            zo_lambda (float): Weight of the auxiliary term. Defaults to 1e-4.

        Returns:
            nn.Module: return the model itself.
        """

        # setup dataset
        trainset.load()
        if validset is not None:
            validset.load()

        loader = DataLoader(
            trainset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=8,
        )

        self._init_scheduler(self.max_epoches, len(loader))
        self.best_params = copy.deepcopy(self.state_dict())
        self.best_network_params = copy.deepcopy(self.network.state_dict())
        iterations = 0
        start_epoch, best_res = self._resume()
        best_epoch = best_res.pop("best_epoch", 0)
        best_score = self.early_stop.best
        # Fallback step for logging test metrics when the loop below does not
        # run at all (start_epoch >= max_epoches); otherwise the loop variable
        # overwrites it.
        epoch = start_epoch

        # main loop
        for epoch in range(start_epoch, self.max_epoches):
            # pre_epoch
            self.train()
            train_loss = AverageMeter()
            train_global_tracker = GlobalTracker(self.metrics, self.metric_fn)
            start_time = time.time()

            # batch loop
            for data, label in loader:
                # pre batch / fetch data
                if use_cuda():
                    data, label = to_torch(data, device="cuda"), to_torch(
                        label, device="cuda"
                    )

                pred = self(data)

                # Restrict both prediction and label to the selected indices.
                if self.out_ranges is not None:
                    pred = pred[:, self.out_ranges]
                    label = label[:, self.out_ranges]

                loss = self.loss_fn(label.squeeze(-1), pred.squeeze(-1))

                if zo_loss:
                    # Sum the per-node auxiliary term and reset each node's
                    # accumulator for the next batch.
                    zo_loss_value = 0
                    num_nodes = 0
                    for m in self.modules():
                        if isinstance(m, PELIFNode) and hasattr(m, "reset"):
                            num_nodes += 1
                            zo_loss_value = zo_loss_value + m.zo_loss_value
                            m.zo_loss_value = 0
                    # Without any PELIFNode there is nothing to average;
                    # dividing by zero here used to abort training.
                    if num_nodes > 0:
                        zo_loss_value = zo_loss_value / num_nodes / self.network.T
                        # Out-of-place add keeps the autograd graph intact.
                        loss = loss + zo_lambda * zo_loss_value

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), 1)
                self.optimizer.step()
                loss = loss.item()
                train_loss.update(loss, np.prod(label.shape))
                train_global_tracker.update(label, pred)
                if self.scheduler is not None:
                    self.scheduler.step()
                iterations += 1

                # post batch
                self._post_batch(
                    iterations,
                    epoch,
                    train_loss,
                    train_global_tracker,
                    validset,
                    testset,
                )

            # post epoch
            train_time = time.time() - start_time
            loss = train_loss.performance()  # loss
            start_time = time.time()
            train_global_tracker.concat()
            metric_res = train_global_tracker.performance()
            metric_time = time.time() - start_time
            metric_res["loss"] = loss

            # log epoch
            printt(
                f"{epoch}\t'train'\tTime:{train_time:.2f}\tMetricT: {metric_time:.2f}\tlearing rate: {self.optimizer.param_groups[0]['lr']}"
            )
            for metric, value in metric_res.items():
                printt(f"{metric}: {value:.4f}")
            print(f"{datetime.datetime.today()}")
            for k, v in metric_res.items():
                self.writer.add_scalar(f"{k}/train", v, epoch)
            self.writer.flush()

            # Pick the value driving early stopping: validation metric when a
            # validation set exists, training metric otherwise.
            if validset is not None:
                with torch.no_grad():
                    eval_res = self.evaluate(validset, epoch)
                value = eval_res[self.observe]
                candidate_res = {"train": metric_res, "valid": eval_res}
            else:
                value = metric_res[self.observe]
                candidate_res = {"train": metric_res}

            es = self.early_stop.step(value)
            if es == EarlyStopStatus.BEST:
                best_score = value
                best_epoch = epoch
                self.best_params = copy.deepcopy(self.state_dict())
                self.best_network_params = copy.deepcopy(self.network.state_dict())
                best_res = candidate_res
                torch.save(
                    self.best_params, f"{self.checkpoint_dir}/model_best.pkl"
                )
                torch.save(
                    self.best_network_params,
                    f"{self.checkpoint_dir}/network_best.pkl",
                )
            elif es == EarlyStopStatus.STOP and self._early_stop():
                break
            self._checkpoint(epoch, {**best_res, "best_epoch": best_epoch})

        # release the space of train and valid dataset
        trainset.freeup()
        if validset is not None:
            validset.freeup()

        # finish training, test, save model and write logs
        self._load_weight(self.best_params)
        if testset is not None:
            testset.load()
            print("Begin evaluate on testset ...")
            with torch.no_grad():
                test_res = self.evaluate(testset)
            for k, v in test_res.items():
                self.writer.add_scalar(f"{k}/test", v, epoch)
            # The reported score becomes the test metric when a test set exists.
            best_score = test_res[self.observe]
            best_res["test"] = test_res
            testset.freeup()
        torch.save(self.best_params, f"{self.checkpoint_dir}/model_best.pkl")
        torch.save(self.best_network_params, f"{self.checkpoint_dir}/network_best.pkl")
        with open(f"{self.checkpoint_dir}/res.json", "w") as f:
            json.dump(best_res, f, indent=4, sort_keys=True)
        print(best_res)
        # Drop, in place, every hyper-parameter whose exact type is not one
        # of the listed scalar/tensor types before handing them to the writer.
        for k in list(self.hyper_paras):
            if type(self.hyper_paras[k]) not in (int, float, str, bool, torch.Tensor):
                self.hyper_paras.pop(k)
        self.writer.add_hparams(
            self.hyper_paras, {"result": best_score, "best_epoch": best_epoch}
        )

        return self
    
    
    def test(
        self,
        testset: Optional[Dataset] = None,
        ckpt_dir = None,
    ):
        """Load a saved checkpoint and evaluate it on ``testset``.

        Args:
            testset: the dataset to evaluate on. Required, despite the
                ``None`` default kept for signature compatibility.
            ckpt_dir: path of the checkpoint passed to ``torch.load``.
                Required, despite the ``None`` default.

        Returns:
            The result of ``self.evaluate(testset)``; it is also printed.

        Raises:
            ValueError: if ``testset`` or ``ckpt_dir`` is ``None``.
        """
        # Fail early with a clear message instead of an obscure error from
        # torch.load(None) or None.load().
        if testset is None:
            raise ValueError("testset must be provided")
        if ckpt_dir is None:
            raise ValueError("ckpt_dir must be provided")

        self._load_weight(torch.load(ckpt_dir))
        testset.load()
        try:
            print("Begin evaluate on testset ...")
            with torch.no_grad():
                test_res = self.evaluate(testset)
        finally:
            # Release the dataset even when evaluation raises.
            testset.freeup()
        print(test_res)
        return test_res
        print(test_res)
