import os
import gc
import torch
from typing import Any, Union, List, Tuple, Dict
from torch import Tensor, nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
from tqdm import trange
from ml4co_kit import to_numpy
from meta_diffusion.env.env import MetaDiffEnv
from meta_diffusion.model.decoder import *
from src.meta_diffusion.model.encoder import GNNEncoder
from meta_diffusion.model.consistency.inference import InferenceSchedule
from meta_diffusion.model.consistency.diffusion import CategoricalDiffusion
from torch.nn.utils.convert_parameters import vector_to_parameters, parameters_to_vector
from matplotlib import pyplot as plt


class MetaDiffModel(object):
    def __init__(
        self,
        env: MetaDiffEnv,
        encoder: GNNEncoder,
        decoder: Dict[str, MetaDiffDecoder],
        train_outer_steps: int = 100,
        train_inner_steps: int = 1,
        train_inner_samples: int = 4,
        val_inner_steps: int = 0,
        inner_lr: float = 5e-5,
        outer_lr: float = 1e-3,
        beam_size: int = -1,
        cm_alpha: float = 0.2,
        cm_beta: float = 0.025,
        inference_steps: int = 1,
        save_n_epochs: int = 10,
        save_path: str = "path/to/save/weights",
        plot_folder: str = None,
        weight_path: str = None,
        grad_norm: bool = True,
        use_pcgrad: bool = False,
        enable_meta: bool = True
    ):
        """Store the configuration and build diffusion, logging and optimizer state.

        Args:
            env: Environment; supplies ``device``, ``task_pool`` and the data
                generators used by training and validation.
            encoder: Network that is trained; kept as ``self.model``.
            decoder: Mapping from task name to the decoder used for that task
                (it is indexed by task name in ``get_sol_val``).
            train_outer_steps: Number of outer iterations in ``model_train``;
                also the ``T_max`` of the cosine LR schedule.
            train_inner_steps: Stored only; not read by the methods of this class.
            train_inner_samples: Number of task samples per outer iteration.
            val_inner_steps: Stored only; not read by the methods of this class.
            inner_lr: Step size of the inner parameter update.
            outer_lr: Learning rate of the AdamW outer optimizer.
            beam_size: Stored only; not read by the methods of this class.
            cm_alpha: Ratio used to derive the second time step
                (``t2 = int(cm_alpha * t1)``) during training.
            cm_beta: Scale of the multiplicative uniform noise applied to the
                noised inputs.
            inference_steps: Number of denoising iterations at inference.
            save_n_epochs: Validation interval in outer steps; a periodic
                checkpoint is written every ``20 * save_n_epochs`` steps.
            save_path: Checkpoint path; ``str.format`` is applied to it with
                ``(step, tag)`` as positional arguments.
            plot_folder: Folder for the training curves; ``None`` disables plots.
            weight_path: Optional checkpoint loaded into the encoder
                (non-strict).
            grad_norm: Whether the inner gradient is divided by its global
                L2 norm.
            use_pcgrad: Stored only; not read by the methods of this class.
            enable_meta: ``model_train`` raises ``NotImplementedError`` when
                this is ``False``.
        """
        self.env: MetaDiffEnv = env
        self.model: GNNEncoder = encoder
        self.decoder: Dict[str, MetaDiffDecoder] = decoder
        self.beam_size = beam_size
        # Fix: the original compared the undefined name ``task`` here, which
        # raised NameError on construction. The energy loss is enabled only
        # when the pool consists of the single task "MCut".
        self.energy_finetune = (
            len(env.task_pool) == 1 and env.task_pool[0] == "MCut"
        )
        self.train_outer_steps = train_outer_steps
        self.train_inner_steps = train_inner_steps
        self.train_inner_samples = train_inner_samples
        self.val_inner_steps = val_inner_steps
        self.save_n_epochs = save_n_epochs
        self.save_path = save_path
        self.plot_folder = plot_folder
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.device = self.env.device
        self.task_pool = self.env.task_pool
        self.grad_norm = grad_norm
        self.use_pcgrad = use_pcgrad
        self.enable_meta = enable_meta

        self.diffusion = CategoricalDiffusion(T=1000, schedule="linear")
        self.cm_alpha = cm_alpha
        self.cm_beta = cm_beta
        self.inference_steps = inference_steps
        self.time_schedule = InferenceSchedule("cosine", 1000, inference_steps)

        # load pretrained weights if needed
        # NOTE(review): torch.load unpickles the file; only point weight_path
        # at checkpoints from a trusted source.
        if weight_path is not None:
            self.state_dict = torch.load(weight_path, map_location=self.device)
            self.model.load_state_dict(self.state_dict, strict=False)
        self.model.to(self.device)

        # Best objective per task: maximised tasks start at 0, minimised
        # tasks at 1e6. Tasks outside both lists get no entry.
        self.best_val_obj = dict()
        for task in self.task_pool:
            if task in ["MIS", "MCl", "MCut"]:
                self.best_val_obj[task] = 0
            elif task in ["TSP", "ATSP", "MVC"]:
                self.best_val_obj[task] = 1e6
        self.best_val_obj.update({"avg_gap": 1e6})

        # One objective history per task plus the meta loss and average gap.
        self.logger = {task: list() for task in self.task_pool}
        self.logger.update({"meta_loss": []})
        self.logger.update({"avg_gap": []})
        if self.plot_folder is not None:
            os.makedirs(self.plot_folder, exist_ok=True)

        self.outer_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.outer_lr, weight_decay=1e-5)
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.outer_optimizer, T_max=self.train_outer_steps, eta_min=0.0)

    def model_train(self):
        """Run the outer training loop with periodic validation and checkpoints.

        Each outer step samples ``train_inner_samples`` tasks. For every sample
        the loss gradient is taken, optionally normalised by its global L2
        norm, and ``params - inner_lr * grad`` is written into the model; the
        loss is then re-evaluated on the same batch and the original
        parameters are written back. The mean of the re-evaluated losses is
        back-propagated and the outer optimizer and LR scheduler are stepped.

        Raises:
            NotImplementedError: If ``enable_meta`` is ``False``.
        """
        self.model.train()
        tbar = trange(self.train_outer_steps)
        for outstep in tbar:
            # Fix: model_eval() switches the network to eval mode, so train
            # mode has to be re-enabled on every step, not just once up front.
            self.model.train()
            if self.enable_meta:
                losses = None
                for _ in range(self.train_inner_samples):
                    task = np.random.choice(self.task_pool)
                    self.model.zero_grad(set_to_none=True)
                    old_params = parameters_to_vector(self.model.parameters())
                    batch = self.env.generate_train_data(task, self.env.train_batch_size)
                    loss = self.get_loss_train(task, batch)

                    grads = torch.autograd.grad(loss, self.model.parameters(), retain_graph=True, create_graph=True, allow_unused=True)
                    # Parameters that did not take part in the loss get a zero gradient.
                    valid_grads = [grad if grad is not None else torch.zeros_like(param) for grad, param in zip(grads, self.model.parameters())]

                    if self.grad_norm:
                        grad_norm = torch.norm(parameters_to_vector(valid_grads), p=2)
                        valid_grads = [g / (grad_norm + 1e-8) for g in valid_grads]

                    # Inner update, written into the live parameters.
                    # NOTE(review): vector_to_parameters presumably copies
                    # values only, in which case the graph built by
                    # create_graph=True is not used by the outer backward
                    # pass -- confirm.
                    new_params = parameters_to_vector(self.model.parameters()) - self.inner_lr * parameters_to_vector(valid_grads)
                    vector_to_parameters(new_params, self.model.parameters())
                    new_loss = self.get_loss_train(task, batch)
                    losses = new_loss.reshape(-1, 1) if losses is None else torch.cat((losses, new_loss.reshape(-1, 1)), dim=0)
                    # Put the pre-update parameters back before the next sample.
                    vector_to_parameters(old_params, self.model.parameters())

                meta_loss = torch.mean(losses)
                self.outer_optimizer.zero_grad()
                meta_loss.backward()
                self.outer_optimizer.step()
                self.lr_scheduler.step()
                tbar.set_description(f"Training loss: {meta_loss.item()}")
                self.logger["meta_loss"].append(meta_loss.item())

                # Drop the references held for this step before collecting.
                del grads, valid_grads, batch, losses, old_params, new_params
                gc.collect()
                torch.cuda.empty_cache()

            else:
                raise NotImplementedError()

            # validation
            if outstep % self.save_n_epochs == 0:
                self.model_eval(outstep)

            # periodic checkpoint
            if outstep % (self.save_n_epochs * 20) == 0:
                # os.makedirs("") raises, so skip it for a bare file name.
                save_dir = os.path.dirname(self.save_path)
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                torch.save(self.model.state_dict(), self.save_path.format(outstep, "log"))

    
    def model_eval(self, outer_step):
        """Validate on every task, track the best results and save checkpoints.

        For each task the objective and gap returned by ``get_sol_val`` are
        averaged over the validation set, logged and compared with the best
        objective so far. When the gap averaged over tasks is finite and lower
        than the best recorded one, a checkpoint is written. The training
        curves are redrawn when ``plot_folder`` is set.

        Args:
            outer_step: Current outer iteration; used as the first ``format``
                argument of ``save_path``.
        """
        # Fix: remember the mode so a training caller is not left in eval mode.
        was_training = self.model.training
        self.model.eval()
        overall_gaps = list()
        overall_objs = list()
        for task in self.task_pool:
            gaps = list()
            objs = list()
            for idx in range(len(self.env.val_dataset)):
                batch = self.env.generate_val_data(task, idx)
                with torch.no_grad():
                    obj, gap = self.get_sol_val(task, batch)
                objs.append(obj.item())
                gaps.append(gap.item())
            obj_mean = np.array(objs).mean()
            gap_mean = np.array(gaps).mean()
            overall_gaps.append(gap_mean)
            overall_objs.append(obj_mean)
            # Higher is better for MIS/MCl/MCut, lower for TSP/ATSP/MVC.
            if (
                (task in ["MIS", "MCl", "MCut"] and obj_mean > self.best_val_obj[task]) or \
                (task in ["TSP", "ATSP", "MVC"] and obj_mean < self.best_val_obj[task])
            ):
                self.best_val_obj[task] = obj_mean
            print(f"{task}: {obj_mean:.3f}, best: {self.best_val_obj[task]:.3f}, gap: {gap_mean:.3f}")
            self.logger[task].append(obj_mean)
        avg_overall_gap = np.array(overall_gaps).mean()
        avg_overall_obj = np.array(overall_objs).mean()
        # os.makedirs("") raises, so the directory is only created when
        # save_path actually has a directory component.
        save_dir = os.path.dirname(self.save_path)
        if not np.isnan(avg_overall_gap) and not np.isinf(avg_overall_gap):
            if avg_overall_gap < self.best_val_obj["avg_gap"]:
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                torch.save(self.model.state_dict(), self.save_path.format(outer_step, f"{avg_overall_gap:.4f}"))
                self.best_val_obj["avg_gap"] = avg_overall_gap
        else: # for HCP and SAT
            # Fix: "ATSP" has an entry only when it is in the task pool; use
            # .get so a NaN/inf gap no longer aborts training with KeyError.
            # NOTE(review): the intent of comparing against the ATSP best is
            # unclear from this file -- confirm.
            if avg_overall_obj == self.best_val_obj.get("ATSP"):
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                torch.save(self.model.state_dict(), self.save_path.format(outer_step, f"{avg_overall_obj:.4f}"))
                self.best_val_obj["avg_gap"] = avg_overall_gap
        self.logger["avg_gap"].append(avg_overall_gap)
        if self.plot_folder is not None:
            self.plot_log()
        if was_training:
            self.model.train()

    def get_loss_train(self, task: str, batch_data: Any):
        """Dispatch to the training routine that matches the task and data layout.

        Args:
            task: Task name. "TSP"/"ATSP" use the edge routines and
                "MIS"/"MCut"/"MCl"/"MVC" use the node routines.
            batch_data: Sequence unpacked positionally into the chosen routine.

        Returns:
            Whatever the selected ``train_*_process`` method returns.

        Raises:
            NotImplementedError: If ``task`` is not one of the names above.
        """
        if task in ["TSP", "ATSP"]:
            # Edge-level tasks; pick the sparse or dense variant.
            process = (
                self.train_edge_sparse_process
                if self.env.sparse
                else self.train_edge_dense_process
            )
        elif task in ["MIS", "MCut", "MCl", "MVC"]:
            # Node-level tasks; pick the sparse or dense variant.
            process = (
                self.train_node_sparse_process
                if self.env.sparse
                else self.train_node_dense_process
            )
        else:
            raise NotImplementedError()
        return process(*batch_data)
    
    def get_sol_val(self, task: str, batch_data: Any, plot_heatmap: bool = False):
        """Run inference and decoding for one validation batch.

        ``env.mode`` is set to ``"val"`` for the duration of the call and is
        set to ``"train"`` afterwards, including when an exception is raised.

        Args:
            task: Task name. "TSP"/"ATSP" use the edge routines and
                "MIS"/"MCut"/"MCl"/"MVC" use the node routines.
            batch_data: Sequence unpacked positionally into the inference
                routine and, after the heatmap, into the decoder.
            plot_heatmap: Not supported; requesting it for an edge task raises
                ``NotImplementedError``.

        Returns:
            The ``(costs_avg, gap_avg)`` pair returned by the task's decoder.

        Raises:
            NotImplementedError: For an unknown task, or when ``plot_heatmap``
                is requested for an edge task.
        """
        self.env.mode = "val"
        # Fix: the reset to "train" used to be skipped whenever inference or
        # decoding raised, leaving the environment stuck in "val" mode.
        try:
            if task in ["TSP", "ATSP"]:
                if self.env.sparse:
                    _, heatmap = self.inference_edge_sparse_process(*batch_data)
                else:
                    _, heatmap = self.inference_edge_dense_process(*batch_data)

                if plot_heatmap:
                    raise NotImplementedError()

            elif task in ["MIS", "MCut", "MCl", "MVC"]:
                if self.env.sparse:
                    _, heatmap = self.inference_node_sparse_process(*batch_data)
                else:
                    _, heatmap = self.inference_node_dense_process(*batch_data)
            else:
                raise NotImplementedError()

            # decoding
            if self.env.sparse:
                costs_avg, gap_avg = self.decoder[task].sparse_decode(heatmap, *batch_data, return_cost=True)
            else:
                costs_avg, gap_avg = self.decoder[task].dense_decode(heatmap, *batch_data, return_cost=True)
        finally:
            self.env.mode = "train"

        return costs_avg, gap_avg

    def train_edge_sparse_process(
        self, task: str, nodes_feature: Tensor, x: Tensor, edges_feature: Tensor, 
        e: Tensor, edge_index: Tensor, graph_list: List[Tensor], ground_truth: Tensor, 
        nodes_num_list: list, edges_num_list: list, ref_tour_list: list
    ) -> Tensor:
        """Compute the two-time-step training loss for edge tasks on sparse graphs.

        A time step ``t1`` is drawn from [1, 1000] and ``t2 = int(cm_alpha * t1)``.
        The ground truth is noised at both steps, perturbed with multiplicative
        uniform noise, passed through the model, and the two losses are summed.

        Args:
            task: Task name forwarded to the model.
            nodes_feature, x, edges_feature, edge_index: Forwarded to the model.
            ground_truth: Labels that are noised and used as loss target.
            e, graph_list, nodes_num_list, edges_num_list, ref_tour_list:
                Accepted for a uniform batch layout; not used here.

        Returns:
            Sum of the losses at ``t1`` and ``t2``.
        """
        # consistency time
        t1: Tensor = torch.randint(1, 1001, size=(1,), device=self.device)
        t2 = (self.cm_alpha * t1).int()

        # diffusion (add noise based on ground truth)
        e_noised_t1, e_noised_t2 = self.diffusion.consistency_sample_sparse(
            x=ground_truth, t1=t1, t2=t2
        )

        # add small random noise
        # Fix: use the same ``value * noise - 0.5`` input scaling as
        # inference_edge_sparse_process and the other training routines.
        # This method previously fed ``(2 * value - 1) * noise``, so the model
        # was trained on inputs scaled differently from those seen at inference.
        e_small_noise_t1 = 1.0 + self.cm_beta * torch.rand_like(e_noised_t1)
        e_small_noise_t2 = 1.0 + self.cm_beta * torch.rand_like(e_noised_t2)
        e_noised_t1 = e_noised_t1 * e_small_noise_t1 - 0.5
        e_noised_t2 = e_noised_t2 * e_small_noise_t2 - 0.5

        # forward
        x_pred_t1, e_pred_t1 = self.model.forward(
            task=task, focus_on_node=False, focus_on_edge=True, 
            nodes_feature=nodes_feature, x=x, edges_feature=edges_feature, 
            e=e_noised_t1, t=t1, edge_index=edge_index
        )
        x_pred_t2, e_pred_t2 = self.model.forward(
            task=task, focus_on_node=False, focus_on_edge=True, 
            nodes_feature=nodes_feature, x=x, edges_feature=edges_feature, 
            e=e_noised_t2, t=t2, edge_index=edge_index
        )
        # Only the edge predictions are needed for edge tasks.
        del x_pred_t1, x_pred_t2

        # loss
        if self.energy_finetune:
            # NOTE(review): train_node_sparse_process passes ``task`` as the
            # first argument of finetune_sparse; this call does not -- confirm
            # against the environment's signature.
            loss_t1 = self.env.finetune_sparse(e_pred_t1, edges_feature, edge_index)
            loss_t2 = self.env.finetune_sparse(e_pred_t2, edges_feature, edge_index)
        else:
            loss_func = nn.CrossEntropyLoss()
            loss_t1 = loss_func(e_pred_t1, ground_truth)
            loss_t2 = loss_func(e_pred_t2, ground_truth)
        loss = loss_t1 + loss_t2
        return loss
   
    def train_edge_dense_process(
        self, task: str, nodes_feature: Tensor, x: Tensor, graph: Tensor, 
        e: Tensor, ground_truth: Tensor, nodes_num_list: list, ref_tour_list: list
    ) -> Tensor:
        """Compute the two-time-step training loss for edge tasks on dense graphs.

        A time step ``t1`` is drawn from [1, 1000] and ``t2 = int(cm_alpha * t1)``.
        The ground truth is noised at both steps, perturbed with multiplicative
        uniform noise, shifted by -0.5, passed through the model, and the two
        losses are summed.

        Args:
            task: Task name forwarded to the model.
            nodes_feature, x: Forwarded to the model.
            graph: Passed to the model as ``edges_feature``.
            ground_truth: Labels that are noised and used as loss target.
            e, nodes_num_list, ref_tour_list: Accepted for a uniform batch
                layout; not used here.

        Returns:
            Sum of the losses at ``t1`` and ``t2``.
        """
        # Pair of time steps: t1 random, t2 a fixed fraction of it.
        t1: Tensor = torch.randint(1, 1001, size=(1,), device=self.device)
        t2 = (self.cm_alpha * t1).int()

        # Noise the labels at both time steps.
        noised_t1, noised_t2 = self.diffusion.consistency_sample_dense(
            x=ground_truth, t1=t1, t2=t2
        )

        # Multiplicative jitter (drawn for t1 first, then t2), then the shift.
        jitter_t1 = 1.0 + self.cm_beta * torch.rand_like(noised_t1)
        jitter_t2 = 1.0 + self.cm_beta * torch.rand_like(noised_t2)
        model_in_t1 = noised_t1 * jitter_t1 - 0.5
        model_in_t2 = noised_t2 * jitter_t2 - 0.5

        # Arguments shared by both forward passes.
        shared = dict(
            task=task, focus_on_node=False, focus_on_edge=True,
            nodes_feature=nodes_feature, x=x, edges_feature=graph,
            edge_index=None,
        )
        # Keep only the edge prediction (second output) of each pass.
        e_pred_t1 = self.model.forward(e=model_in_t1, t=t1, **shared)[1]
        e_pred_t2 = self.model.forward(e=model_in_t2, t=t2, **shared)[1]

        if self.energy_finetune:
            loss_t1 = self.env.finetune_dense(e_pred_t1, graph)
            loss_t2 = self.env.finetune_dense(e_pred_t2, graph)
        else:
            criterion = nn.CrossEntropyLoss()
            loss_t1 = criterion(e_pred_t1, ground_truth)
            loss_t2 = criterion(e_pred_t2, ground_truth)
        return loss_t1 + loss_t2
    
    def train_node_sparse_process(
        self, task: str, nodes_feature: Tensor, x: Tensor, edges_feature: Tensor, 
        e: Tensor, edge_index: Tensor, graph_list: List[Tensor], ground_truth: Tensor,
        nodes_num_list: list, edges_num_list: list, ref_tour_list: list
    ) -> Tensor:
        """Compute the two-time-step training loss for node tasks on sparse graphs.

        A time step ``t1`` is drawn from [1, 1000] and ``t2 = int(cm_alpha * t1)``.
        The ground truth is noised at both steps, perturbed with multiplicative
        uniform noise, shifted by -0.5, fed to the model as the node state, and
        the two losses are summed.

        Args:
            task: Task name forwarded to the model (and to the energy loss).
            nodes_feature, edges_feature, e, edge_index: Forwarded to the model.
            ground_truth: Labels that are noised and used as loss target.
            x, graph_list, nodes_num_list, edges_num_list, ref_tour_list:
                Accepted for a uniform batch layout; not used here.

        Returns:
            Sum of the losses at ``t1`` and ``t2``.
        """
        # Pair of time steps: t1 random (sampled on CPU, then moved), t2 a
        # fixed fraction of it.
        t1: Tensor = torch.randint(1, 1001, size=(1,)).to(self.device)
        t2 = (self.cm_alpha * t1).int()

        # Noise the labels at both time steps.
        noised_t1, noised_t2 = self.diffusion.consistency_sample_sparse(
            x=ground_truth, t1=t1, t2=t2
        )

        # Multiplicative jitter (drawn for t1 first, then t2), then the shift.
        jitter_t1 = 1.0 + self.cm_beta * torch.rand_like(noised_t1)
        jitter_t2 = 1.0 + self.cm_beta * torch.rand_like(noised_t2)
        model_in_t1 = noised_t1 * jitter_t1 - 0.5
        model_in_t2 = noised_t2 * jitter_t2 - 0.5

        # Arguments shared by both forward passes.
        shared = dict(
            task=task, focus_on_node=True, focus_on_edge=False,
            nodes_feature=nodes_feature, edges_feature=edges_feature,
            e=e, edge_index=edge_index,
        )
        # Keep only the node prediction (first output) of each pass.
        x_pred_t1 = self.model.forward(x=model_in_t1, t=t1, **shared)[0]
        x_pred_t2 = self.model.forward(x=model_in_t2, t=t2, **shared)[0]

        if self.energy_finetune:
            loss_t1 = self.env.finetune_sparse(task, x_pred_t1, edges_feature, edge_index)
            loss_t2 = self.env.finetune_sparse(task, x_pred_t2, edges_feature, edge_index)
        else:
            criterion = nn.CrossEntropyLoss()
            loss_t1 = criterion(x_pred_t1, ground_truth)
            loss_t2 = criterion(x_pred_t2, ground_truth)

        return loss_t1 + loss_t2

    def train_node_dense_process(
        self, task: str, nodes_feature: Tensor, x: Tensor, graph: Tensor, 
        e: Tensor, ground_truth: Tensor, nodes_num_list: list, ref_tour_list: list
    ) -> Tensor:
        """Placeholder for node-task training on dense graphs.

        Raises:
            NotImplementedError: Always; this variant is not implemented.
        """
        raise NotImplementedError

    def inference_edge_sparse_process(
        self, task: str, nodes_feature: Tensor, x: Tensor, edges_feature: Tensor, 
        e: Tensor, edge_index: Tensor, graph_list: List[Tensor], ground_truth: Tensor,
        nodes_num_list: list, edges_num_list: list, ref_tour_list: list
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Denoise a random edge state into an edge heatmap on sparse graphs.

        Starting from a random binary state shaped like ``e``, each of the
        ``inference_steps`` iterations perturbs the state, runs the model and,
        while the schedule's second time step is non-zero, samples the state
        for the next iteration through ``diffusion.Q_bar``.

        Returns:
            ``(loss, heatmap)`` when ``env.mode == "val"``; the heatmap alone
            when ``env.mode == "solve"``. The heatmap is class 1 of the
            softmax of the last edge prediction.

        Raises:
            ValueError: If ``env.mode`` is neither "val" nor "solve".
        """
        # Random binary start state.
        edge_state = (torch.randn_like(e) > 0).float()

        for step in range(self.inference_steps):
            # Current and next time step of the schedule.
            t_now, t_next = self.time_schedule(step)
            t_now = torch.tensor([t_now], device=self.device).float()

            # Multiplicative jitter, then shift.
            jitter = 1.0 + self.cm_beta * torch.rand_like(edge_state)
            edge_state = edge_state * jitter - 0.5

            # Only the edge prediction (second output) is kept.
            e_pred = self.model.forward(
                task=task, focus_on_node=False, focus_on_edge=True,
                nodes_feature=nodes_feature, x=x, edges_feature=edges_feature,
                e=edge_state, t=t_now, edge_index=edge_index
            )[1]
            e_probs = e_pred.softmax(-1)

            if t_next != 0:
                # Sample a clean state, then re-noise it to time t_next.
                sampled = torch.bernoulli(e_probs[..., 1].clamp(0, 1))
                sampled_onehot: Tensor = F.one_hot(sampled.long(), num_classes=2)
                q_bar = torch.from_numpy(self.diffusion.Q_bar[t_next]).float().to(self.device)
                renoised = torch.matmul(sampled_onehot.float(), q_bar)
                edge_state = torch.bernoulli(renoised[..., 1].clamp(0, 1))

        heatmap = e_probs[:, 1]

        mode = self.env.mode
        if mode == "val":
            val_loss = nn.CrossEntropyLoss()(e_pred, ground_truth)
            return val_loss, heatmap
        if mode == "solve":
            return heatmap
        raise ValueError()

    def inference_edge_dense_process(
        self, task: str, nodes_feature: Tensor, x: Tensor, graph: Tensor, 
        e: Tensor, ground_truth: Tensor, nodes_num_list: list, ref_tour_list: list
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Denoise a random edge state into an edge heatmap on dense graphs.

        Starting from a random binary state shaped like ``e``, each of the
        ``inference_steps`` iterations perturbs the state, runs the model (with
        ``graph`` as edge features) and, while the schedule's second time step
        is non-zero, samples the state for the next iteration through
        ``diffusion.Q_bar``.

        Returns:
            ``(loss, heatmap)`` when ``env.mode == "val"``; the heatmap alone
            when ``env.mode == "solve"``. The heatmap is index 1 along dim 1
            of the softmax (taken over dim 1) of the last edge prediction.

        Raises:
            ValueError: If ``env.mode`` is neither "val" nor "solve".
        """
        # Random binary start state.
        edge_state = (torch.randn_like(e) > 0).float()

        for step in range(self.inference_steps):
            # Current and next time step of the schedule.
            t_now, t_next = self.time_schedule(step)
            t_now = torch.tensor([t_now], device=self.device).float()

            # Multiplicative jitter, then shift.
            jitter = 1.0 + self.cm_beta * torch.rand_like(edge_state)
            edge_state = edge_state * jitter - 0.5

            # Only the edge prediction (second output) is kept.
            e_pred = self.model.forward(
                task=task, focus_on_node=False, focus_on_edge=True,
                nodes_feature=nodes_feature, x=x, edges_feature=graph,
                e=edge_state, t=t_now, edge_index=None
            )[1]
            # Class scores live on dim 1 in the dense layout.
            e_probs = e_pred.softmax(1)

            if t_next != 0:
                # Sample a clean state, then re-noise it to time t_next.
                sampled = torch.bernoulli(e_probs[:, 1, :].clamp(0, 1))
                sampled_onehot: Tensor = F.one_hot(sampled.long(), num_classes=2)
                q_bar = torch.from_numpy(self.diffusion.Q_bar[t_next]).float().to(self.device)
                renoised = torch.matmul(sampled_onehot.float(), q_bar)
                edge_state = torch.bernoulli(renoised[..., 1].clamp(0, 1))

        heatmap = e_probs[:, 1, :, :]

        mode = self.env.mode
        if mode == "val":
            criterion = nn.CrossEntropyLoss()
            val_loss = criterion(e_pred, ground_truth)
            return val_loss, heatmap
        if mode == "solve":
            return heatmap
        raise ValueError()

    def inference_node_sparse_process(
        self, task: str, nodes_feature: Tensor, x: Tensor, edges_feature: Tensor, 
        e: Tensor, edge_index: Tensor, graph_list: List[Tensor], ground_truth: Tensor,
        nodes_num_list: list, edges_num_list: list, ref_tour_list: list
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Denoise a random node state into a node heatmap on sparse graphs.

        Starting from a random binary state shaped like ``x``, each of the
        ``inference_steps`` iterations perturbs the state, runs the model and,
        while the schedule's second time step is non-zero, samples the state
        for the next iteration through ``diffusion.Q_bar``.

        Returns:
            ``(loss, heatmap)`` when ``env.mode == "val"``; the heatmap alone
            when ``env.mode == "solve"``. The heatmap is class 1 of the
            softmax of the last node prediction.

        Raises:
            ValueError: If ``env.mode`` is neither "val" nor "solve".
        """
        # Random binary start state.
        node_state = (torch.randn_like(x) > 0).long().float()

        for step in range(self.inference_steps):
            # Current and next time step of the schedule.
            t_now, t_next = self.time_schedule(step)
            t_now = torch.tensor([t_now], device=self.device).float()

            # Multiplicative jitter, then shift.
            jitter = 1.0 + self.cm_beta * torch.rand_like(node_state)
            node_state = node_state * jitter - 0.5

            # Only the node prediction (first output) is kept.
            x_pred = self.model.forward(
                task=task, focus_on_node=True, focus_on_edge=False,
                nodes_feature=nodes_feature, x=node_state,
                edges_feature=edges_feature, e=e, t=t_now, edge_index=edge_index
            )[0]
            x_probs = x_pred.softmax(-1)

            if t_next != 0:
                # Sample a clean state, then re-noise it to time t_next.
                sampled = torch.bernoulli(x_probs[..., 1].clamp(0, 1))
                sampled_onehot: Tensor = F.one_hot(sampled.long(), num_classes=2)
                q_bar = torch.from_numpy(self.diffusion.Q_bar[t_next]).float().to(self.device)
                renoised = torch.matmul(sampled_onehot.float(), q_bar)
                node_state = torch.bernoulli(renoised[..., 1].clamp(0, 1))

        heatmap = x_probs[:, 1]

        mode = self.env.mode
        if mode == "val":
            val_loss = nn.CrossEntropyLoss()(x_pred, ground_truth)
            return val_loss, heatmap
        if mode == "solve":
            return heatmap
        raise ValueError()
        
    def inference_node_dense_process(
        self, task: str, nodes_feature: Tensor, x: Tensor, graph: Tensor, 
        e: Tensor, ground_truth: Tensor, nodes_num_list: list,
        ref_tour_list: list = None
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Placeholder for node-task inference on dense graphs.

        ``ref_tour_list`` was added (optional, so existing calls keep working)
        to match the parameter list of the other dense routines. Without it,
        a batch that also carries ``ref_tour_list`` and is unpacked
        positionally into this method fails with ``TypeError`` before the
        intended ``NotImplementedError`` is reached.

        Raises:
            NotImplementedError: Always; this variant is not implemented.
        """
        raise NotImplementedError()
    
    def plot_log(self):
        """Draw the logged training curves and save them as a PDF.

        One panel is drawn per task (logged objective values), followed by a
        panel for the meta loss and one for the average gap. The file is
        written to ``plot_folder`` and overwritten on every call.
        """
        meta_loss = self.logger["meta_loss"]
        avg_gap = self.logger["avg_gap"]

        fig, axes = plt.subplots(1, len(self.task_pool) + 2, figsize=(5 * (len(self.task_pool) + 2), 5))
        # Fix: the figure was never closed, so every validation round leaked
        # one pyplot figure. Close it even if drawing or saving fails.
        try:
            fig.suptitle("Training Logs", fontsize=12)

            for i, task in enumerate(self.task_pool):
                obj_values = self.logger[task]
                # Objectives are logged once per validation round.
                epochs = np.arange(len(obj_values)) * self.save_n_epochs

                ax = axes[i]
                ax.plot(epochs, obj_values, 'b-', label=f'{task}_obj')
                ax.set_title(f"Task: {task}")
                ax.set_xlabel("Outer Epoch")
                ax.set_ylabel("Objective Value")
                ax.grid(True)

            # The meta loss is logged on every outer step.
            ax = axes[-2]
            epochs = list(range(len(meta_loss)))
            ax.plot(epochs, meta_loss, 'r-', label='meta_loss')
            ax.set_title("Meta Loss")
            ax.set_xlabel("Outer Epoch")
            ax.set_ylabel("Loss")
            ax.grid(True)

            ax = axes[-1]
            epochs = np.arange(len(avg_gap)) * self.save_n_epochs
            ax.plot(epochs, avg_gap, 'r-', label='avg_gap')
            ax.set_title("Avg gap")
            ax.set_xlabel("Outer Epoch")
            ax.set_ylabel("Gap")
            ax.grid(True)

            # Leave room at the top for the figure title.
            plt.tight_layout(rect=[0, 0, 1, 0.97])
            # NOTE(review): the file name embeds the repr of task_pool
            # (brackets and quotes included); kept as-is so existing output
            # paths do not change.
            save_plot_path = os.path.join(self.plot_folder, f"latest_curve_{self.task_pool}.pdf")
            os.makedirs(os.path.dirname(save_plot_path), exist_ok=True)
            plt.savefig(save_plot_path)
        finally:
            plt.close(fig)

    def freeze_layers(self, layer_names):
        """Disable gradients for parameters whose name contains a given substring.

        Args:
            layer_names: Iterable of substrings; a parameter is frozen when
                any of them occurs in its name from ``named_parameters()``.
        """
        for param_name, parameter in self.model.named_parameters():
            matches = [key for key in layer_names if key in param_name]
            if not matches:
                continue
            parameter.requires_grad = False