import random

from comb_modules.utils import cached_vertex_grid_to_edges, cached_vertex_grid_to_edges_grid_coords
import time
from abc import ABC, abstractmethod

import torch
from comb_modules.losses import HammingLoss
from comb_modules.dijkstra import ShortestPath, get_solver
from logger import Logger
from models import get_model
from utils import AverageMeter, optimizer_from_string, customdefaultdict
from decorators import to_tensor, to_numpy
from . import metrics
from .metrics import compute_metrics
import numpy as np
from collections import defaultdict
def get_trainer(trainer_name):
    """Return the trainer class registered under ``trainer_name``.

    Only ``"DijkstraAttacker"`` is registered. Any other name raises
    ``KeyError`` from the dictionary lookup.
    """
    registry = dict(DijkstraAttacker=DijkstraAttacker)
    return registry[trainer_name]
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
from .visualization import draw_paths_on_image

from utils import minimum, maximum
import time
from PIL import Image

class AttackAbstractTrainer(ABC):
    """Base class for optimising an additive perturbation on one sample.

    The trainer keeps a single ``(image, path, weights)`` triple in
    ``self.sample`` and a perturbation ``self.delta``. Each call to
    :meth:`train_epoch` adds the (box-constrained) perturbation to the image,
    runs :meth:`forward_pass` on the result and takes one optimizer step.

    ``sample``, ``delta`` and ``optimizer`` are *not* created here; they are
    initialised to ``None`` and must be assigned by the caller before the
    first call to :meth:`train_epoch`.

    Subclasses must implement :meth:`build_model` and :meth:`forward_pass`.
    """

    # Number of epochs between on-screen previews of the perturbed image.
    PREVIEW_EVERY = 50

    def __init__(
        self,
        *,
        train_iterator,
        test_iterator,
        metadata,
        use_cuda,
        batch_size,
        optimizer_name,
        optimizer_params,
        model_params,
        fast_mode,
        neighbourhood_fn,
        preload_batch,
        lr_milestone_1,
        lr_milestone_2,
        use_lr_scheduling
    ):
        """Store the configuration and build the model.

        Args:
            train_iterator: Accepted for interface compatibility; not used
                by this class.
            test_iterator: Accepted for interface compatibility; not used
                by this class.
            metadata: Mapping read for ``"output_features"`` here and for
                ``"denormalize"`` when a preview is rendered.
            use_cuda: When true, the built model is moved to ``"cuda"``.
            batch_size: Stored on the instance.
            optimizer_name: Accepted for interface compatibility; not used
                by this class (the optimizer is assigned externally).
            optimizer_params: Stored on the instance.
            model_params: Keyword arguments forwarded to :meth:`build_model`.
            fast_mode: Stored on the instance.
            neighbourhood_fn: Stored on the instance.
            preload_batch: Stored on the instance.
            lr_milestone_1: Accepted for interface compatibility; not used.
            lr_milestone_2: Accepted for interface compatibility; not used.
            use_lr_scheduling: Accepted for interface compatibility; not used.
        """
        self.fast_mode = fast_mode
        self.use_cuda = use_cuda
        self.optimizer_params = optimizer_params
        self.batch_size = batch_size

        self.metadata = metadata
        # Side length of the output grid, assuming "output_features" is a
        # perfect square.
        self.grid_dim = int(np.sqrt(self.metadata["output_features"]))
        self.neighbourhood_fn = neighbourhood_fn
        self.preload_batch = preload_batch

        # Placeholders for state that is loaded externally after construction.
        self.model = None
        self.sample = None
        self.delta = None
        self.maxdelta = None
        self.optimizer = None
        self.image_min = None
        self.image_max = None

        self.build_model(**model_params)

        if self.use_cuda:
            self.model.to("cuda")
        self.epochs = 0
        # Largest absolute perturbation allowed per element.
        self.maxdelta = 1.0

    @abstractmethod
    def build_model(self, **model_params):
        """Create the model and assign it to ``self.model``.

        Called once from ``__init__`` with ``model_params`` expanded as
        keyword arguments.
        """

    @abstractmethod
    def forward_pass(self, input, true_shortest_paths, train, i):
        """Run the model on ``input`` and score it against the true paths.

        Must return a ``(loss, accuracy, last_suggestion)`` triple where
        ``loss`` and ``accuracy`` support ``.item()``, ``loss`` supports
        ``.backward()``, and ``last_suggestion`` is a mapping containing the
        key ``"suggested_path"``.
        """

    @staticmethod
    def _clamp_to_box(delta, lower, upper):
        """Clamp ``delta`` element-wise into the box ``[lower, upper]``."""
        return torch.max(torch.min(delta, upper), lower)

    def _show_preview(self, adv_image):
        """Denormalise ``adv_image`` and briefly display it with PIL."""
        pixels = self.metadata["denormalize"](adv_image)
        # NOTE(review): (2, 1, 0) also swaps the two spatial axes; presumably
        # this matches how the dataset stores images -- confirm.
        frame = np.transpose(pixels.squeeze().astype(np.uint8), (2, 1, 0))
        im = Image.fromarray(frame)
        im.show()
        time.sleep(0.5)
        im.close()

    def train_epoch(self):
        """Perform one optimisation step on the perturbation.

        Returns:
            dict with ``"train_loss"``, ``"train_accuracy"``, ``"path"`` (the
            path suggested by :meth:`forward_pass`) and one ``"train_<k>"``
            entry per key present in the per-metric meters.

        Raises:
            RuntimeError: if ``sample``, ``delta`` or ``optimizer`` has not
                been assigned yet.
        """
        # Fail early, before the epoch counter is touched, with a message
        # that names the actual problem.
        if self.sample is None or self.delta is None or self.optimizer is None:
            raise RuntimeError(
                "train_epoch() requires `sample`, `delta` and `optimizer` "
                "to be assigned on the trainer first"
            )

        self.epochs += 1
        avg_loss = AverageMeter("Loss")
        avg_accuracy = AverageMeter("Accuracy")
        avg_metrics = customdefaultdict(lambda k: AverageMeter("train_" + k))

        # The sample is a triple; the true weights are not needed here.
        image, true_path, _ = self.sample
        # Add a leading batch axis of size one.
        # NOTE(review): these tensors are created on the CPU while the model
        # is moved to CUDA when use_cuda is set -- confirm device handling.
        image = torch.Tensor(image).unsqueeze(0)
        true_path = torch.Tensor(true_path).unsqueeze(0)

        # Per-element box for the perturbation: the perturbed value must stay
        # within [image.min(), image.max()] and |delta| <= maxdelta.
        lower = maximum(image.min() - image, -self.maxdelta * torch.ones_like(image))
        upper = minimum(image.max() - image, self.maxdelta * torch.ones_like(image))

        # Clamp against the full box (batch axis removed so the result keeps
        # the shape it had with scalar bounds) instead of against the bounds
        # of the first element only.
        bounded_delta = self._clamp_to_box(
            self.delta, lower.squeeze(0), upper.squeeze(0)
        )
        adv_image = image + bounded_delta.unsqueeze(0)

        # Rendering is only needed on preview epochs.
        if self.epochs % self.PREVIEW_EVERY == 0:
            self._show_preview(adv_image)

        loss, accuracy, last_suggestion = self.forward_pass(
            adv_image, true_path, train=True, i=0
        )
        suggested_path = last_suggestion["suggested_path"]
        avg_loss.update(loss.item(), image.size(0))
        avg_accuracy.update(accuracy.item(), image.size(0))

        print("loss = {}".format(loss))

        # Compute the gradient and take one optimizer step.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {
            "train_loss": avg_loss.avg,
            "train_accuracy": avg_accuracy.avg,
            "path": suggested_path,
            **{"train_" + k: avg_metrics[k].avg for k in avg_metrics.keys()}
        }


# def get_solver(neighbourhood_fn):
#     def solver(matrix):
#         return dijkstra(matrix, neighbourhood_fn).shortest_path
#
#     return solver




class DijkstraAttacker(AttackAbstractTrainer):
    """Attack trainer whose model output is fed to a shortest-path solver.

    The model output is made non-negative, reshaped into square grids and
    passed to ``ShortestPath.apply``; the resulting paths are compared to the
    true paths with a ``HammingLoss``, plus a mean-of-weights penalty scaled
    by ``l1_regconst``.
    """

    def __init__(self, *, l1_regconst, lambda_val, **kwargs):
        """Build the base trainer, then attach solver and loss.

        Args:
            l1_regconst: Multiplier of the mean absolute model output that is
                added to the loss.
            lambda_val: Stored on the instance; not passed to the solver by
                this class.
            **kwargs: Forwarded unchanged to ``AttackAbstractTrainer``.
        """
        super().__init__(**kwargs)
        self.l1_regconst = l1_regconst
        self.lambda_val = lambda_val
        self.temp_holder = {}
        self.solver = ShortestPath.apply
        self.loss_fn = HammingLoss()

        print("META:", self.metadata)

    def build_model(self, model_name, arch_params):
        """Instantiate the model via ``get_model`` and keep it in ``self.model``.

        The output feature count and the input channel count are read from
        ``self.metadata``.
        """
        meta = self.metadata
        self.model = get_model(
            model_name,
            out_features=meta["output_features"],
            in_channels=meta["num_channels"],
            arch_params=arch_params,
        )

    def forward_pass(self, input, true_shortest_paths, train, i):
        """Predict weights, solve for shortest paths and score them.

        Args:
            input: Batch handed to the model.
            true_shortest_paths: Reference paths compared against the solver
                output.
            train: Unused by this implementation.
            i: Unused by this implementation.

        Returns:
            ``(loss, accuracy, last_suggestion)`` where ``last_suggestion``
            holds ``"suggested_weights"`` and ``"suggested_path"``.
        """
        # Grid weights must be non-negative.
        positive = torch.abs(self.model(input))
        side = positive.shape[-1]
        weights = positive.reshape(-1, side, side)

        shortest_paths = self.solver(weights)
        loss = self.loss_fn(shortest_paths, true_shortest_paths)

        # Fraction of entries whose prediction is within 0.5 of the target.
        deviation = torch.abs(shortest_paths - true_shortest_paths)
        accuracy = (deviation < 0.5).to(torch.float32).mean()

        # Regulariser on the mean (already non-negative) model output.
        loss += self.l1_regconst * torch.mean(positive)

        last_suggestion = {
            "suggested_weights": weights,
            "suggested_path": shortest_paths
        }
        return loss, accuracy, last_suggestion
