import os

import numpy as np
import torch
import torch.nn.functional as F
from config import setup_logger
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch_geometric.data import Batch
from torch_geometric.data import Data
from tsp.policy import Policy
from tsp.policy import TSPSurrogate
from tsp.policy import get_cost

import wandb

# Module-level logger obtained from the project's ``config.setup_logger`` helper.
logger = setup_logger()


class TSPDataset(Dataset):
    """Pre-generated TSP instances paired with their reference costs.

    All instances are loaded eagerly in ``__init__`` and their costs are computed
    once with ``get_cost``, so ``__getitem__`` is a plain index lookup.
    """

    def __init__(
        self,
        folder_path: str,
        mode: str,
        n_node: int,
        batch_size: int,
        n_shards: int = 10,
    ) -> None:
        """Load the instances for ``mode`` from ``folder_path`` and solve them all.

        Args:
            folder_path: directory containing the saved tensors.
            mode: ``"training"`` (loads ``n_shards`` shard files and concatenates
                them along dim 0) or ``"validation"`` (loads a single file).
            n_node: number of nodes per instance; part of the file name.
            batch_size: part of the file name.
            n_shards: number of training shard files ``..._{i}`` for
                ``i in range(n_shards)``. Ignored for validation.

        Raises:
            ValueError: if ``mode`` is unknown or ``n_shards`` is not positive.
        """
        super().__init__()
        self.folder_path = folder_path
        # Load everything up front rather than lazily per item.
        if mode == "training":
            if n_shards < 1:
                raise ValueError(f"n_shards must be >= 1, got {n_shards}")
            shards = [
                torch.load(os.path.join(folder_path, f"{mode}_data_{n_node}_{batch_size}_{i}"))
                for i in range(n_shards)
            ]
            self.data = torch.cat(shards, 0)
        elif mode == "validation":
            self.data = torch.load(os.path.join(folder_path, f"{mode}_data_{n_node}_{batch_size}"))
        else:
            raise ValueError(f"mode must be 'training' or 'validation', got {mode!r}")

        # Solve every instance once. ``__getitem__`` indexes the result per
        # instance; its exact type/shape is whatever ``get_cost`` returns (confirm).
        self.acc_length_list = get_cost(self.data)

    def __len__(self) -> int:
        """Return the number of loaded instances (size of ``self.data`` along dim 0)."""
        return len(self.data)

    def __getitem__(self, index: int):
        """Return ``(instance, reference_cost)`` for the instance at ``index``."""
        return self.data[index], self.acc_length_list[index]


class TSPTrainer:
    """Train a ``Policy`` network to regress reference TSP tour costs.

    The policy is trained with an MSE loss against the costs that ``TSPDataset``
    precomputes with ``get_cost``. After training passes it is evaluated on the
    validation split, and a checkpoint is written whenever the validation loss
    improves.
    """

    # Graph construction for the policy input. ``None`` builds a fully connected
    # graph (self-loops included). A float keeps only node pairs whose Euclidean
    # distance is <= that radius; self-loops are kept as well, since a node's
    # distance to itself is 0.
    # NOTE(review): training and validation use different graphs, exactly as the
    # previous implementation did; confirm that this mismatch is intentional.
    TRAIN_ADJ_RADIUS = None
    VALIDATION_ADJ_RADIUS = 0.3

    def __init__(
        self,
        n_nodes: int,
        n_agent: int,
        batch_size: int,
        in_chnl: int,
        hid_chnl: int,
        key_size_embd: int,
        key_size_policy: int,
        val_size: int,
        clipping: int,
        lr_p: float,
        dev: str,
        output: str,
        disable_softmax: bool,
    ) -> None:
        """Load the data, then build the surrogate, the policy and their optimizers.

        Args:
            n_nodes: nodes per instance; also selects which data files are loaded.
            n_agent: forwarded to ``Policy``; also part of the checkpoint file name.
            batch_size: training mini-batch size; also selects which data files
                are loaded and is part of the checkpoint file name.
            in_chnl, hid_chnl, key_size_embd, key_size_policy, val_size, clipping,
                disable_softmax: forwarded unchanged to ``Policy``.
            lr_p: learning rate for both the policy (Adam) and the surrogate
                (RMSprop) optimizers.
            dev: torch device string for the models and the batched graphs.
            output: directory under which ``saved_tsp_model/`` checkpoints are written.
        """
        self.device = dev
        self.n_nodes = n_nodes
        self.n_agent = n_agent
        self.batch_size = batch_size
        self.disable_softmax = disable_softmax

        self.train_dataset = TSPDataset(
            folder_path="training_data", mode="training", n_node=n_nodes, batch_size=batch_size
        )
        self.train_dataloader = DataLoader(dataset=self.train_dataset, batch_size=batch_size, shuffle=True)

        self.validation_dataset = TSPDataset(
            folder_path="validation_data", mode="validation", n_node=n_nodes, batch_size=batch_size
        )
        # Evaluate the whole validation split as one batch. The value is clamped
        # to >= 1 because DataLoader rejects batch_size=0.
        self.validation_dataloader = DataLoader(
            dataset=self.validation_dataset,
            batch_size=max(1, len(self.validation_dataset)),
            shuffle=False,
        )
        # Same file the validation dataset already loaded; reuse it instead of
        # reading it from disk a second time.
        self.validation_data = self.validation_dataset.data
        self.reset()

        self.surrogate = TSPSurrogate(dev=dev, n_nodes=n_nodes)
        self.optim_surrogate = torch.optim.RMSprop(
            self.surrogate.parameters(), lr=lr_p, momentum=0.468, weight_decay=0.067
        )
        # NOTE(review): this scheduler is never stepped anywhere in this class.
        self.scheduler_surrogate = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optim_surrogate, min_lr=1e-6, patience=50, factor=0.5, verbose=True
        )

        self.policy = Policy(
            n_nodes=n_nodes,
            in_chnl=in_chnl,
            hid_chnl=hid_chnl,
            n_agent=n_agent,
            key_size_embd=key_size_embd,
            key_size_policy=key_size_policy,
            val_size=val_size,
            clipping=clipping,
            dev=dev,
            disable_softmax=disable_softmax,
        )
        self.saved_model_path = f"{output}/saved_tsp_model"
        os.makedirs(self.saved_model_path, exist_ok=True)
        # Checkpoint path for a given iteration number.
        self.path = (
            lambda x: f"{self.saved_model_path}/iMTSP_{str(n_nodes)}node_{str(n_agent)}agent_{str(batch_size)}batch_{x}itr.pth"
        )
        self.optim_p = torch.optim.Adam(self.policy.parameters(), lr=lr_p)

        # TODO(finetuning): support resuming from a saved checkpoint by loading the
        # policy/surrogate state_dicts (map_location=dev). When resuming, also
        # restore ``best_so_far`` and double-check hyperparameters such as the
        # learning rate.

    def reset(self) -> None:
        """Forget previous validation results so the next validation always saves."""
        # Any finite validation loss beats this start point; change it when resuming.
        self.best_so_far = np.inf
        self.validation_results = []

    def save(self, itr: int) -> None:
        """Write the policy's ``state_dict`` to the checkpoint path for ``itr``."""
        torch.save(self.policy.state_dict(), self.path(itr))
        # NOTE(review): the surrogate is not saved; reusing self.path(itr) for it
        # would overwrite the policy checkpoint.

    def _build_batch_graph(self, data: torch.Tensor, radius=None) -> Batch:
        """Turn a ``(batch, n_nodes, dim)`` node-feature tensor into a PyG ``Batch``.

        Args:
            data: per-instance node features; presumably 2-D coordinates (confirm).
            radius: ``None`` for a fully connected graph (self-loops included).
                Otherwise an edge ``j -> k`` exists when the Euclidean distance
                between nodes ``j`` and ``k`` is <= ``radius``.

        Returns:
            The batched graph, moved to ``self.device``.
        """
        n_batch, n_nodes = data.shape[0], data.shape[1]
        if radius is None:
            adj = torch.ones([n_batch, n_nodes, n_nodes])
        else:
            # Pairwise Euclidean distances, shape (batch, n_nodes, n_nodes).
            distances = torch.norm(data[:, :, None, :] - data[:, None, :, :], dim=-1)
            adj = distances <= radius
        data_list = [
            Data(x=x, edge_index=torch.nonzero(a, as_tuple=False).t()) for x, a in zip(data, adj)
        ]
        return Batch.from_data_list(data_list=data_list).to(self.device)

    def _to_target(self, acc_length) -> torch.Tensor:
        """Convert a collated batch of reference costs to a float32 tensor on ``self.device``."""
        # as_tensor avoids the copy-construct warning torch.tensor() emits for tensor inputs.
        return torch.as_tensor(acc_length, device=self.device, dtype=torch.float32)

    def validate(self, itr: int) -> None:
        """Evaluate the policy on the validation split.

        If the summed MSE beats the best value seen so far, the policy is saved
        as the checkpoint for ``itr`` and the result is recorded.
        """
        self.policy.eval()
        total_validation_result = 0.0
        # No gradients are needed here. Accumulating Python floats (not tensors)
        # also avoids keeping every batch's autograd graph alive.
        with torch.no_grad():
            for data, acc_length in self.validation_dataloader:
                batch_graph = self._build_batch_graph(data, radius=self.VALIDATION_ADJ_RADIUS)
                # Use the actual batch size; the validation batch is the whole split.
                pred_length = self.policy(batch_graph, n_nodes=data.shape[1], n_batch=data.shape[0])
                total_validation_result += F.mse_loss(pred_length, self._to_target(acc_length)).item()

        if total_validation_result < self.best_so_far:
            self.save(itr)
            logger.info(
                f"Found better policy, and the validation result is: {total_validation_result}",
            )
            self.validation_results.append(total_validation_result)
            self.best_so_far = total_validation_result
        wandb.log({"val loss": total_validation_result, "best val so far": self.best_so_far})
        print(f"current : {total_validation_result} | best : {self.best_so_far}")

    def train_one_batch(self) -> None:
        """Run one full pass over the training dataloader, taking one optimizer step per mini-batch.

        Despite the name, this consumes every mini-batch of ``self.train_dataloader``.
        """
        self.policy.train()
        for data, acc_length in self.train_dataloader:
            # The last mini-batch can be smaller than self.batch_size.
            n_batch = data.shape[0]
            batch_graph = self._build_batch_graph(data, radius=self.TRAIN_ADJ_RADIUS)

            pred_length = self.policy(batch_graph, n_nodes=data.shape[1], n_batch=n_batch)
            target = self._to_target(acc_length)

            loss = F.mse_loss(pred_length, target)
            self.optim_p.zero_grad()
            loss.backward()
            self.optim_p.step()

            # Mean absolute error of the batch; assumes pred_length and target share
            # shape (batch,) - TODO confirm against Policy's output.
            gap = (pred_length.detach() - target).abs().mean().item()
            # "cost" keeps the historical scaling (mean loss / batch size).
            wandb.log({"average gap": gap, "cost": loss.item() / n_batch})

    def __call__(self, iterations, validate_every: int = 1):
        """Train for ``iterations`` passes over the training data.

        Validation runs after every ``validate_every`` passes, and after every pass
        until a first validation result has been recorded. A final checkpoint is
        saved when at least one pass ran.

        Raises:
            ValueError: if ``validate_every`` is not positive.
        """
        if validate_every < 1:
            raise ValueError(f"validate_every must be >= 1, got {validate_every}")
        self.reset()
        itr = None
        for itr in range(iterations):
            self.train_one_batch()
            if (itr + 1) % validate_every == 0 or not self.validation_results:
                logger.info(f"current iteration : {itr}")
                self.validate(itr)
        if itr is not None:
            self.save(itr)
        logger.info("finish")