import os
import random
from pickle import load

# Seed Python's built-in `random` module at import time; MapDataset's flip
# augmentation draws from it. NumPy and torch generators are not seeded by this call.
random.seed(0)

import numpy as np
import ot
import torch
from config import setup_logger
from LinSATNet import linsat_layer
from mm_cvrp.policy import Policy
from mm_cvrp.policy import Surrogate
from mm_cvrp.policy import action_sample
from mm_cvrp.policy import get_cost
from mm_cvrp.utils import Loss
from torch.distributions import Categorical
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch_geometric.data import Batch
from torch_geometric.data import Data

import wandb

# Module-level logger obtained from the project's `config.setup_logger` helper.
logger = setup_logger()


class MapDataset(Dataset):
    """Dataset of location maps read from ``{folder_path}/{mode}_location.pickle``.

    Row 0 of every map is treated as the depot.  Items are returned translated
    so that the depot sits at the origin, as ``float32`` arrays.
    """

    def __init__(self, folder_path: str, mode: str, return_depot: bool = False, augmentation: bool = True) -> None:
        """Load all maps into memory.

        Args:
            folder_path: Directory containing the pickle file.
            mode: File-name prefix, i.e. ``{mode}_location.pickle`` is opened.
            return_depot: When true, ``__getitem__`` also returns the depot row.
            augmentation: When true, each axis is mirrored (``x -> 1 - x``)
                with probability 0.5 on every access.
        """
        super().__init__()
        self.return_depot = return_depot
        self.augmentation = augmentation

        # NOTE(security): unpickling executes arbitrary code; only load files from a trusted source.
        with open(f"{folder_path}/{mode}_location.pickle", "rb") as f:
            self.locations = load(f)

    def __len__(self) -> int:
        """Return the number of stored maps."""
        return len(self.locations)

    def __getitem__(self, index: int):
        """Return map ``index`` relative to its depot (optionally with the depot).

        Returns:
            ``location`` as a ``float32`` array, or ``(location, depot)`` when
            ``return_depot`` is set.  ``depot`` is the (possibly mirrored)
            depot row before the translation.
        """
        # Work on a private copy: the mirroring below assigns in place, and
        # doing that on the stored array would permanently alter the dataset.
        location = np.array(self.locations[index])
        if self.augmentation:
            if random.random() >= 0.5:
                location[:, 0] = 1 - location[:, 0]
            if random.random() >= 0.5:
                location[:, 1] = 1 - location[:, 1]

        # Read the depot only after mirroring so the translation uses the
        # mirrored depot; copy it so the caller never receives a view.
        depot = location[0].copy()
        location = np.array(location - depot, dtype=np.float32)

        if self.return_depot:
            return location, depot
        return location


class Trainer:
    def __init__(
        self,
        n_nodes: int,
        n_agent: int,
        batch_size: int,
        in_chnl: int,
        hid_chnl: int,
        key_size_embd: int,
        key_size_policy: int,
        val_size: int,
        clipping: int,
        lr_p: float,
        lr_s: float,
        dev: str,
        src_vector: torch.Tensor,
        output: str,
        loss: Loss,
        capacity: int,
        disable_softmax: bool,
        train_folder: str,
        validation_folder: str,
        augmentation: bool = True,
    ) -> None:
        """Build datasets, the policy and surrogate networks, and their optimizers.

        Args:
            n_nodes, n_agent, batch_size: Problem and batch dimensions; also
                embedded in checkpoint file names.
            in_chnl, hid_chnl, key_size_embd, key_size_policy, val_size, clipping:
                Forwarded unchanged to ``Policy``.
            lr_p, lr_s: Learning rates of the policy / surrogate RMSprop optimizers.
            dev: Device handed to the networks and used for constraint tensors.
            src_vector: Source marginal used by the ``"original"`` optimal-transport path.
            output: Root directory under which checkpoint folders are created.
            loss: Loss selector; compared against ``Loss.<name>.value`` during training.
            capacity: Per-agent bound used in the LinSAT capacity constraint.
            disable_softmax: Forwarded to ``Policy``; also part of the checkpoint folder name.
            train_folder, validation_folder: Directories holding the pickled maps.
            augmentation: Enables flip augmentation on the training set only.
        """

        def plateau_scheduler(optimizer):
            # Both optimizers share the same plateau schedule settings.
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, min_lr=1e-6, patience=50, factor=0.5, verbose=True
            )

        def policy_checkpoint(x):
            return f"{self.saved_model_path}/iMTSP_{str(n_nodes)}node_{str(n_agent)}agent_{str(batch_size)}batch_{x}itr.pth"

        def surrogate_checkpoint(x):
            return f"{self.saved_surrogate_model_path}/iMTSP_{str(n_nodes)}node_{str(n_agent)}agent_{str(batch_size)}batch_{x}itr.pth"

        # plain configuration
        self.capacity = capacity
        self.disable_softmax = disable_softmax
        self.loss = loss
        self.src_vector = src_vector
        self.batch_size = batch_size
        self.n_agent = n_agent
        self.n_nodes = n_nodes
        self.device = dev

        # training data (shuffled, incomplete final batch dropped)
        self.train_dataset = MapDataset(folder_path=train_folder, mode="train", augmentation=augmentation)
        self.train_dataloader = DataLoader(
            dataset=self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
        # validation data (fixed order, never augmented)
        self.validation_dataset = MapDataset(folder_path=validation_folder, mode="validation", augmentation=False)
        self.validation_dataloader = DataLoader(
            dataset=self.validation_dataset,
            batch_size=batch_size,
            shuffle=False,
        )
        self.reset()

        # policy network, its checkpoint location and optimizer
        policy_kwargs = dict(
            in_chnl=in_chnl,
            hid_chnl=hid_chnl,
            n_agent=n_agent,
            key_size_embd=key_size_embd,
            key_size_policy=key_size_policy,
            val_size=val_size,
            clipping=clipping,
            dev=dev,
            disable_softmax=disable_softmax,
        )
        self.policy = Policy(**policy_kwargs)
        self.saved_model_path = f"{output}/saved_model_{self.loss}_softmax{not self.disable_softmax}"
        os.makedirs(self.saved_model_path, exist_ok=True)
        self.path = policy_checkpoint
        self.optim_p = torch.optim.RMSprop(self.policy.parameters(), lr=lr_p, momentum=0.468, weight_decay=0.067)
        self.scheduler_p = plateau_scheduler(self.optim_p)

        # surrogate network: its input is the flattened (node, xy, agent) tensor built during training
        self.surrogate = Surrogate(in_dim=n_nodes * n_agent * 2, out_dim=n_agent, n_hidden=256, nonlin="tanh", dev=dev)
        self.saved_surrogate_model_path = (
            f"{output}/saved_surrogate_model_{self.loss}_softmax{not self.disable_softmax}"
        )
        os.makedirs(self.saved_surrogate_model_path, exist_ok=True)
        self.path_s = surrogate_checkpoint
        self.optim_s = torch.optim.RMSprop(self.surrogate.parameters(), lr=lr_s, momentum=0.202, weight_decay=0.336)
        self.scheduler_s = plateau_scheduler(self.optim_s)

        # FIXME fine-tuning / resuming is not implemented: no checkpoint is loaded here.
        # When it is added, re-check hyperparameters such as the learning rate and
        # the best validation result stored by `reset`.

    def reset(self) -> None:
        # a large start point
        self.best_so_far = np.inf  # change when resuming
        self.validation_results = []

    def validate(self, itr: int) -> None:
        # validation_result = validate(self.validation_data, self.policy, self.n_agent, self.device, self.src_vector)
        def calculate_optimal_transport(pi: torch.Tensor, n_nodes) -> torch.Tensor:
            pi_tmp = pi.reshape(len(pi), -1)

            # capacity constraint
            A = torch.zeros([self.n_agent, self.n_agent * (n_nodes - 1)], device=self.device)
            b = torch.tensor([25] * self.n_agent, dtype=torch.float32, device=self.device)
            for i in range(self.n_agent):
                A[i, i * (n_nodes - 1) : (i + 1) * (n_nodes - 1)] = 1

            # １つは必ずアサインされる(総和が１)
            # the column constrain
            E = torch.zeros([n_nodes - 1, self.n_agent * (n_nodes - 1)], device=self.device)
            column_gap = (n_nodes - 1) * torch.arange(self.n_agent)
            for i in range(n_nodes - 1):
                E[i, i + column_gap] = 1
            f = torch.ones(n_nodes - 1, dtype=torch.float32, device=self.device)

            E2 = torch.zeros([2, self.n_agent * (n_nodes - 1)], device=self.device)
            # 0番目のagentに対して前半はアサイン不可能
            E2[0, : n_nodes // 2] = 1
            # 1番目は後半に対してアサイン不可能
            E2[1, (n_nodes - 1) * 2 - n_nodes // 2 : (n_nodes - 1) * 2 + 1] = 1
            f2 = torch.zeros(2, dtype=torch.float32, device=self.device)
            E = torch.concat([E, E2])
            f = torch.concat([f, f2])

            output = linsat_layer(pi_tmp, A=A, b=b, E=E, f=f, tau=1e-4, max_iter=800)
            result = output.reshape(batch_size, self.n_agent, n_nodes - 1)
            return result

            output = None
            # 各地点の割当先を求めたいので、１が並んだベクトルを用意
            dst_vector = torch.ones(pi.shape[2], dtype=torch.float32, device=self.device)
            for i in range(pi.shape[0]):
                target = 1 - pi[i]
                P = ot.emd(src_vector, dst_vector, target)
                if abs(P.sum() - pi.shape[2]) >= 0.1:
                    breakpoint()
                if output is None:
                    output = P.unsqueeze(0)
                else:
                    output = torch.cat((output, P.unsqueeze(0)), 0)

            return output

        validation_result = 0
        for instances in self.validation_dataloader:
            batch_size = instances.shape[0]
            adj = torch.ones([batch_size, instances.shape[1], instances.shape[1]])  # adjacent matrix

            # get batch graphs instances list
            instances_list = [Data(x=instances[i], edge_index=torch.nonzero(adj[i]).t()) for i in range(batch_size)]
            # generate batch graph
            batch_graph = Batch.from_data_list(data_list=instances_list).to(self.device)

            # get pi
            pi = self.policy(batch_graph, n_nodes=instances.shape[1], n_batch=batch_size)
            n_nodes = instances.shape[1]

            pi = calculate_optimal_transport(pi, n_nodes)
            # sample action and calculate log probs
            action, log_prob = action_sample(pi)

            # get reward for each batch
            reward, _ = get_cost(action, instances, self.n_agent)  # reward: tensor [batch, 1]
            # print('Validation result:', format(sum(reward)/batch_size, '.4f'))

            validation_result += sum(reward)
        validation_result /= self.validation_dataset.__len__()

        wandb.log({"best val so far": validation_result})
        if validation_result < self.best_so_far:
            torch.save(self.policy.state_dict(), self.path(itr))
            torch.save(self.surrogate.state_dict(), self.path_s(itr))
            logger.info(
                f"Found better policy, and the validation result is: {validation_result}",
            )
            self.validation_results.append(validation_result)
            self.best_so_far = validation_result

    # def calculate_optimal_transport(self, pi: torch.Tensor, src_vector: torch.Tensor, dst_vector: torch.Tensor):
    def calculate_optimal_transport(self, pi: torch.Tensor, method: str = "linsatnet") -> torch.Tensor:
        def run_linsatnet():
            pi_tmp = pi.reshape(len(pi), -1)

            # capacity constraint
            A = torch.zeros([self.n_agent, self.n_agent * (self.n_nodes - 1)], device=self.device)
            b = torch.tensor([self.capacity] * self.n_agent, dtype=torch.float32, device=self.device)
            for i in range(self.n_agent):
                A[i, i * (self.n_nodes - 1) : (i + 1) * (self.n_nodes - 1)] = 1

            # １つは必ずアサインされる(総和が１)
            # the column constrain
            E = torch.zeros([self.n_nodes - 1, self.n_agent * (self.n_nodes - 1)], device=self.device)
            column_gap = (self.n_nodes - 1) * torch.arange(self.n_agent)
            for i in range(self.n_nodes - 1):
                E[i, i + column_gap] = 1
            f = torch.ones(self.n_nodes - 1, dtype=torch.float32, device=self.device)

            # E2 = torch.zeros([2, self.n_agent * (self.n_nodes - 1)], device=self.device)
            # # 0番目のagentに対して前半はアサイン不可能
            # E2[0, : self.n_nodes // 2] = 1
            # # 1番目は後半に対してアサイン不可能
            # E2[1, (self.n_nodes - 1) * 2 - self.n_nodes // 2 : (self.n_nodes - 1) * 2 + 1] = 1
            # f2 = torch.zeros(2, dtype=torch.float32, device=self.device)
            # E = torch.concat([E, E2])
            # f = torch.concat([f, f2])

            output = linsat_layer(pi_tmp, A=A, b=b, E=E, f=f, tau=1e-4, max_iter=800)
            result = output.reshape(self.batch_size, self.n_agent, self.n_nodes - 1)

            return result

        # どれか１つにアサインされる
        def run_original_ot():
            output = None
            # 各地点の割当先を求めたいので、１が並んだベクトルを用意
            dst_vector = torch.ones(pi.shape[2], dtype=torch.float32, device=self.device)
            for i in range(pi.shape[0]):
                target = 1 - pi[i]
                P = ot.emd(self.src_vector, dst_vector, target)
                if output is None:
                    output = P.unsqueeze(0)
                else:
                    output = torch.cat((output, P.unsqueeze(0)), 0)

            return output

        if method == "linsatnet":
            return run_linsatnet()
        elif method == "original":
            return run_original_ot()
        else:
            raise ValueError("specify eithr linsatnet or original")

    def train_one_epoch(self, itr: int):
        c = 0
        for data in self.train_dataloader:
            c += len(data)
            print(itr, c, len(self.train_dataset.locations), round(c / len(self.train_dataset.locations), 3))
            # prepare training data
            # data = torch.load(
            #     f"./training_data/training_data_{str(self.n_nodes)}_{str(self.batch_size)}_{str(itr % 10)}"
            # )  # [batch, nodes, fea], fea is 2D location
            adj = torch.ones([data.shape[0], data.shape[1], data.shape[1]])  # adjacent matrix fully connected
            data_list = [
                Data(x=data[i], edge_index=torch.nonzero(adj[i], as_tuple=False).t()) for i in range(data.shape[0])
            ]
            batch_graph = Batch.from_data_list(data_list=data_list).to(self.device)

            # get pi
            # batch_size x n_agent x n_location (w.o. depot)
            pi = self.policy(batch_graph, n_nodes=data.shape[1], n_batch=self.batch_size)
            if self.loss != Loss.iMTSP.value:
                pi2 = self.calculate_optimal_transport(pi)

            # sample action and calculate log probabilities
            action, log_prob = action_sample(pi2)

            # 0割り避け
            pi = pi + 1e-8
            norm_pi = pi / pi.sum(dim=1, keepdim=True)
            dist = Categorical(norm_pi.transpose(2, 1))
            log_prob = dist.log_prob(action)
            log_prob_sum = log_prob.sum(dim=1)

            # get real cost for each batch
            # data : batch_size x n_location (w depot) x 2 (x,y)
            cost, cost_list = get_cost(
                action, data, self.n_agent
            )  # cost: tensor [batorch.cat([torch.reshape(p, [-1]) for p in pg_grads], 0)tch, 1]
            # estimate cost via the surrogate network
            # pi2 : batch x vehicle x (node-1)
            # data : batch x node x 2(x,y)
            # data2 : batch x node x 2(x,y) x n_agent
            data2 = data.unsqueeze(-1).expand(-1, -1, -1, self.n_agent).to(self.device)
            pi2 = pi2.transpose(2, 1).unsqueeze(2).expand(-1, -1, 2, -1)
            box = torch.ones([self.batch_size, 1, 2, self.n_agent], device=self.device)
            pi2 = torch.cat([box, pi2], dim=1)
            input_surrogate = data2 * pi2
            # cost_s = torch.squeeze(self.surrogate(log_prob))
            cost_s = torch.squeeze(self.surrogate(input_surrogate))
            # compute loss, need to freeze surrogate's parameters, cost_s in the second term should be detached
            loss_surrogate = (torch.tensor(cost_list, device=self.device) - cost_s).square().sum()
            loss_shortest = cost_s.square().sum()
            # loss_cost_prob = torch.mul(torch.tensor(cost, device=self.device), log_prob.sum(dim=1)).sum()
            loss_cost_prob = torch.mul(torch.max(cost_s, 1).values, -log_prob.sum(dim=1)).sum()
            loss_prob_gap = torch.distributions.Categorical(pi).entropy().sum()
            wandb.log({"loss_surrogate": loss_surrogate})
            wandb.log({"loss_shortest": loss_shortest})
            wandb.log({"loss_cost_prob": loss_cost_prob})
            if self.loss == Loss.proposed.value:
                loss = loss_surrogate * 10**3 + loss_shortest * 10**3
            elif self.loss == Loss.proposed2.value:
                loss = torch.mul(torch.tensor(cost, device=self.device), log_prob.sum(dim=1)).sum() + loss_surrogate
            elif self.loss == Loss.proposed3.value:
                loss = loss_surrogate + loss_cost_prob
            elif self.loss == Loss.proposed4.value:
                loss = loss_surrogate + loss_cost_prob + loss_prob_gap
            elif self.loss == Loss.proposed5.value:
                loss = loss_surrogate
            elif self.loss == Loss.iMTSP.value:
                loss = (
                    torch.mul(torch.tensor(cost, device=self.device) - 2, log_prob.sum(dim=1)).sum()
                    - torch.mul(cost_s.detach() - 2, log_prob.sum(dim=1)).sum()
                    + (cost_s - 2).sum()
                )
            elif self.loss == Loss.iMTSP2.value:
                loss = (
                    torch.mul(torch.tensor(cost, device=self.device), log_prob.sum(dim=1)).sum()
                    - torch.mul(cost_s.detach(), log_prob.sum(dim=1)).sum()
                    + (cost_s).sum()
                )
            else:
                raise ValueError("")

            # self.scheduler_p.step(torch.tensor(cost, device=self.device).sum())
            # self.scheduler_s.step(torch.tensor(cost, device=self.device).sum())
            # compute gradient's variance loss w.r.t. surrogate's parameter
            # grad_p = torch.autograd.grad(
            #     loss,
            #     self.policy.parameters(),
            #     grad_outputs=torch.ones_like(loss),
            #     create_graph=True,
            #     retain_graph=True,
            # )
            # grad_temp = torch.cat([torch.reshape(p, [-1]) for p in grad_p], 0)
            # grad_ps = torch.square(grad_temp).mean(0)
            # wandb.log({"variance": grad_ps})
            # grad_s = torch.autograd.grad(
            #     grad_ps,
            #     self.surrogate.parameters(),
            #     grad_outputs=torch.ones_like(grad_ps),
            #     retain_graph=True,
            #     allow_unused=True,
            # )
            # Optimize the policy net
            self.optim_p.zero_grad()
            self.optim_s.zero_grad()
            loss.backward()
            self.optim_p.step()
            self.optim_s.step()
            # Optimize the surrogate net
            # for params, grad in zip(self.surrogate.parameters(), grad_s, strict=False):
            #     params.grad = grad
            # if itr % 100 == 0:
            # logger.info(f"Iteration:{itr}")
            wandb.log({"cost": sum(cost) / self.batch_size})
            wandb.log({"diff of cost": (sum(cost) - sum(cost_s).detach()) / self.batch_size})

    def __call__(self, iterations):
        self.reset()
        for itr in range(iterations):
            self.train_one_epoch(itr)
            # if (itr + 1) % 10 == 0 or len(self.validation_results) == 0:
            logger.info(f"current iteration : {itr}")
            self.validate(itr)
        logger.info("finish")
