import copy
import random
import time
from multiprocessing import Process
import numpy as np
import torch

from AbstractClass.AbstractMetaStructure import AbstractMetaStructure, AbstractLossFunc
from AbstractClass.AbstractModel import AbstractModel
from AbstractClass.TaskRelatedClasses import AbstractTaskConsumer
from torch.utils.data import DataLoader

from SourceCode.ModelModule.LossFunc import AutomaticWeightedLoss


class MetaGraphSketch(AbstractMetaStructure):
    """Meta-training driver for a sketch model trained on support/query tasks.

    A background process runs ``task_producer.produce_train_task`` and feeds a
    queue; each optimization step consumes one task through ``task_consumer``,
    writes its support set into the model, queries the model on its query set
    and back-propagates the resulting loss.
    """

    def __init__(self, task_producer, task_consumer: AbstractTaskConsumer, model: AbstractModel,
                 loss_func: AbstractLossFunc, device,
                 optimizer, logger, checkpointing_task_num=50):
        """Store the collaborators used by the training loop.

        Args:
            task_producer: object whose ``produce_train_task`` runs in a child
                process and fills the task queue.
            task_consumer: reads tasks from the queue and releases them.
            model: model exposing ``clear``/``write``/``query``.
            loss_func: callable ``loss_func(prediction, target)``.
            device: device the model (and an optional next loss) is moved to.
            optimizer: optimizer stepping the model parameters.
            logger: evaluation/checkpoint logger.
            checkpointing_task_num: number of tasks handed to the logger before
                each training phase.
        """
        self.task_consumer = task_consumer
        self.task_producer = task_producer
        self.device = device
        self.model = model
        self.loss_func = loss_func
        self.optimizer = optimizer
        self.logger = logger
        self.checkpointing_task_num = checkpointing_task_num

    @staticmethod
    def _avoid_singleton_batch(set_size, batch_size):
        """Randomly nudge ``batch_size`` so the last mini-batch is not a single item.

        The adjustment only runs while ``batch_size < set_size``: a batch size
        that already covers the whole set is never used for mini-batching, and
        restricting the walk to ``[1, set_size]`` guarantees it ends quickly
        (``set_size % 1 == 0`` and ``set_size % set_size == 0``).

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError('batch size must be positive, got {}'.format(batch_size))
        while batch_size < set_size and set_size % batch_size == 1:
            batch_size += 1 if random.random() > 0.5 else -1
        return batch_size

    def _write_support(self, support_set, support_batch_size):
        """Write the support set into the model, whole or in mini-batches."""
        if support_batch_size >= len(support_set):
            self.model.write(support_set.support_x, support_set.support_y)
            return
        for support_x, support_y in DataLoader(support_set, batch_size=support_batch_size):
            self.model.write(support_x, support_y)

    def _query_loss(self, query_set, query_batch_size, stream_length):
        """Query the model on the query set and return the (mean) loss tensor."""

        def predict(query_x):
            # One copy of the scalar stream length per query row.
            return self.model.query(query_x, stream_length.unsqueeze(-1).repeat(query_x.shape[0], 1))

        if query_batch_size >= len(query_set):
            return self.loss_func(predict(query_set.query_x), query_set.query_y)
        query_data_loader = DataLoader(query_set, batch_size=query_batch_size, shuffle=True)
        loss_list = [self.loss_func(predict(query_x), query_y)
                     for query_x, query_y in query_data_loader]
        # Average of the per-batch losses (not weighted by batch size).
        return torch.stack(loss_list).sum() / len(loss_list)

    def _refresh_checkpointing_tasks(self, q, pass_cuda_tensor):
        """Replace the logger's checkpointing tasks with freshly consumed ones."""
        self.logger.clean_tasks_for_checkpoingting()
        for _ in range(self.checkpointing_task_num):
            meta_task = self.task_consumer.consume_train_task(q, pass_cuda_tensor)
            self.logger.add_tasks_for_checkpointing(meta_task)

    def optimize_one_step(self, step, q, pass_cuda_tensor, support_batch_size, query_batch_size):
        """Consume one task from ``q`` and run a single optimization step.

        Args:
            step: 1-based step counter; also drives periodic logging.
            q: queue filled by the producer process.
            pass_cuda_tensor: forwarded to the task consumer.
            support_batch_size: mini-batch size for writing the support set.
            query_batch_size: mini-batch size for querying the query set.

        Raises:
            ValueError: if a batch size is not positive.
        """
        # Logs on steps 1, eval_gap + 1, 2 * eval_gap + 1, ...
        # NOTE(review): with eval_gap == 1 this condition is never true.
        if step % self.logger.eval_gap == 1:
            self.logger.logging(self.model, step)
        meta_task = self.task_consumer.consume_train_task(q, pass_cuda_tensor)
        support_set = meta_task.support_set
        query_set = meta_task.query_set
        print(step, ' step train begin, item size:', support_set.support_y.shape[0], end='  ')
        # Sum of the support targets, kept out of the autograd graph;
        # presumably the total stream length fed to the model — TODO confirm.
        with torch.no_grad():
            stream_length = support_set.support_y.sum()
        support_batch_size = self._avoid_singleton_batch(len(support_set), support_batch_size)
        query_batch_size = self._avoid_singleton_batch(len(query_set), query_batch_size)
        self.model.clear()
        self._write_support(support_set, support_batch_size)
        loss = self._query_loss(query_set, query_batch_size, stream_length)
        print('loss:', loss.item())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.)
        self.optimizer.step()
        self.optimizer.zero_grad()
        torch.cuda.empty_cache()
        self.model.normalize_attention_matrix()
        self.task_consumer.del_meta_task(meta_task)

    def train(self, train_step, support_batch_size, query_batch_size, pass_cuda_tensor=False, queue_size=20,
              pre_ratio_list=None, pre_item_lower_list=None, pre_item_upper_list=None, next_loss=None,
              change_loss_step=None):
        """Run the full training schedule with a background task producer.

        Args:
            train_step: total number of training steps (pre phases included).
            support_batch_size: mini-batch size for support sets.
            query_batch_size: mini-batch size for query sets.
            pass_cuda_tensor: forwarded to producer and consumer.
            queue_size: capacity of the task queue.
            pre_ratio_list: optional fractions of ``train_step``; each one
                defines a preliminary phase of ``int(ratio * train_step)`` steps.
            pre_item_lower_list: forwarded to the producer for the pre phases.
            pre_item_upper_list: forwarded to the producer for the pre phases.
            next_loss: optional loss that replaces ``loss_func`` during the
                main phase.
            change_loss_step: main-phase step at which ``next_loss`` takes over.

        ``next_loss`` and ``change_loss_step`` must be given together or not
        at all.
        """
        assert (next_loss is None) == (change_loss_step is None), 'change loss parameter error'
        q = torch.multiprocessing.Queue(queue_size)
        if pre_ratio_list is not None:
            # Preliminary phases ("pre for speed up").
            pre_step_list = [int(ratio * train_step) for ratio in pre_ratio_list]
            producer_args = (q, pass_cuda_tensor, pre_item_lower_list, pre_item_upper_list,
                             pre_step_list, self.checkpointing_task_num)
        else:
            pre_step_list = []
            producer_args = (q, pass_cuda_tensor)
        p = Process(target=self.task_producer.produce_train_task, args=producer_args)
        p.start()
        # From here on the producer must be stopped even if training fails,
        # otherwise the non-daemon child keeps running and blocks interpreter exit.
        try:
            self.model.memory_matrix.device = self.device
            self.model.to(self.device)
            if next_loss is not None:
                next_loss.to(self.device)
            for pre_step in pre_step_list:
                self._refresh_checkpointing_tasks(q, pass_cuda_tensor)
                print('begin pre step')
                for step in range(1, pre_step + 1):
                    self.optimize_one_step(step, q, pass_cuda_tensor, support_batch_size, query_batch_size)
            main_step = train_step - sum(pre_step_list)
            self._refresh_checkpointing_tasks(q, pass_cuda_tensor)
            print('begin main step')
            # Steps 1 .. main_step + 1: one more iteration than main_step.
            # NOTE(review): presumably so the periodic logging can fire once
            # after the last regular step — confirm this is intended.
            for step in range(1, main_step + 2):
                if step == change_loss_step:
                    self.loss_func = next_loss
                    self.logger.loss_func = next_loss
                    print('change loss !')

                self.optimize_one_step(step, q, pass_cuda_tensor, support_batch_size, query_batch_size)
        finally:
            p.terminate()
            p.join()
            self.logger.close_all_file()

