import torch
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, transforms
from torch.nn.parallel import DataParallel

import torchmetrics

from injectionModel import resnet18, resnet34, resnet50, resnet101, resnet152
import statistics
import wandb

from AffectNet.AffectNet_Dataloader import AffectNetDataLoader
from FER.FER2013_Dataloader import FER2013DataLoader

import ssl
# SECURITY NOTE: this swaps the default HTTPS context factory for the
# unverified one, so every HTTPS connection that relies on the ssl module's
# default context in this process skips certificate verification.
# Presumably added so the torchvision dataset download works on a machine
# with a broken certificate store -- TODO confirm and scope it more narrowly.
ssl._create_default_https_context = ssl._create_unverified_context


# Module-level compute device: prefer CUDA, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

def make_state_for_MLP_agent(queues):
    """Build the flat agent state from the four metric histories.

    Args:
        queues: Object exposing ``train_loss_queue``, ``train_f1_queue``,
            ``val_loss_queue`` and ``val_f1_queue``; each of those exposes
            its stored values through a ``.queue`` attribute.

    Returns:
        The four ``.queue`` values joined with ``+`` in the order
        train loss, train F1, validation loss, validation F1.
    """
    histories = (
        queues.train_loss_queue.queue,
        queues.train_f1_queue.queue,
        queues.val_loss_queue.queue,
        queues.val_f1_queue.queue,
    )

    # Fold left to right with ``+`` so the result type follows the operands.
    state = histories[0]
    for history in histories[1:]:
        state = state + history
    return state

def make_state_for_LSTM_agent(queues, QUEUE_ITER_SIZE):
    """Build the sequence-shaped agent state from the four metric histories.

    Each history is cut into ``QUEUE_ITER_SIZE`` consecutive, non-overlapping
    chunks of ``len(history) // QUEUE_ITER_SIZE`` items. Sequence element
    ``i`` is the ``i``-th chunk of every history, concatenated in the order
    train loss, train F1, validation loss, validation F1.

    The previous implementation sliced ``history[i : i + delta]``, which
    advanced by one item per step instead of by ``delta``. The windows
    overlapped and the most recent entries were never used when ``delta > 1``.

    Args:
        queues: Object exposing ``train_loss_queue``, ``train_f1_queue``,
            ``val_loss_queue`` and ``val_f1_queue``; each of those exposes a
            sliceable ``.queue`` attribute.
        QUEUE_ITER_SIZE: Number of sequence elements to produce.

    Returns:
        A list with ``QUEUE_ITER_SIZE`` entries, each a list of metric values.

    Raises:
        ZeroDivisionError: If ``QUEUE_ITER_SIZE`` is 0.
    """
    histories = (
        queues.train_loss_queue.queue,
        queues.train_f1_queue.queue,
        queues.val_loss_queue.queue,
        queues.val_f1_queue.queue,
    )

    # Chunk length per history; any remainder at the tail is dropped.
    chunk_sizes = [len(history) // QUEUE_ITER_SIZE for history in histories]

    state = []
    for i in range(QUEUE_ITER_SIZE):
        step_features = []
        for history, size in zip(histories, chunk_sizes):
            # Stride by the chunk size so chunks do not overlap.
            step_features.extend(history[i * size:(i + 1) * size])
        state.append(step_features)
    return state



class Environment:
    """Training environment that exposes a child ResNet classifier to an agent.

    Each call to :meth:`step` applies the agent's action to the child model
    through the model's ``envSet`` method, trains for up to
    ``TRAIN_STEP_PER_ITER`` batches, evaluates for up to ``VAL_STEP_PER_ITER``
    batches and returns ``(reward, next_state, done)``.
    """

    # Default CIFAR100 download/cache directory. Kept identical to the
    # previously hard-coded value so existing callers behave the same.
    _DEFAULT_CIFAR100_ROOT = "C:/Users/dabmv/PycharmProjects/pythonProject/dataset"

    def __init__(self, ACTION_IDX, TRAIN_STEP_PER_ITER, VAL_STEP_PER_ITER, child_model_batch_size, child_model_lr, child_model, QUEUE_ITER_SIZE, agent, GPUparallel, dataset, today, types, checkpointFolderPath, cifar100_root=None):
        """Create the data loaders, child model, loss, optimizer and metric.

        Args:
            ACTION_IDX: Forwarded unchanged to the model's ``envSet``.
            TRAIN_STEP_PER_ITER: Maximum number of training batches per step.
            VAL_STEP_PER_ITER: Maximum number of validation batches per step.
            child_model_batch_size: Batch size for the data loaders.
            child_model_lr: Learning rate of the Adam optimizer.
            child_model: One of ``"resnet18"``, ``"resnet34"``,
                ``"resnet50"``, ``"resnet101"``, ``"resnet152"``.
            QUEUE_ITER_SIZE: Sequence length passed to the LSTM state builder.
            agent: Agent name. ``"LSTM_agent"`` receives the sequence-shaped
                state; every other name receives the flat MLP state.
            GPUparallel: When truthy, the model is wrapped in ``DataParallel``.
            dataset: ``"CIFAR100"``, ``"FER2013"`` or ``"AffectNet"``.
            today: Used only as part of the checkpoint file name.
            types: Used only as part of the checkpoint file name.
            checkpointFolderPath: Directory that receives checkpoints.
            cifar100_root: Optional CIFAR100 root directory; defaults to the
                historical hard-coded path.

        Raises:
            ValueError: If ``dataset`` or ``child_model`` is not recognised.
        """
        super().__init__()

        # Hyper params
        self.ACTION_IDX = ACTION_IDX
        self.TRAIN_STEP_PER_ITER = TRAIN_STEP_PER_ITER
        self.VAL_STEP_PER_ITER = VAL_STEP_PER_ITER
        self.child_model_batch_size = child_model_batch_size
        self.child_model_lr = child_model_lr
        self.child_model = child_model
        self.agent = agent
        self.dataset = dataset
        self.type = types
        self.checkpointFolderPath = checkpointFolderPath

        self.QUEUE_ITER_SIZE = QUEUE_ITER_SIZE
        self.GPUparallel = GPUparallel

        # Both builders are always attached. The old condition
        # ``agent == "MLP_agent" or "EnsembledMLP_agent"`` was always truthy,
        # so the LSTM builder was never set and never used.
        self.make_state_for_MLP_agent = make_state_for_MLP_agent
        self.make_state_for_LSTM_agent = make_state_for_LSTM_agent

        # Dataset definition
        self._build_dataloaders(dataset, cifar100_root)

        self.train_data_size = len(self.train_dataloader.dataset)
        self.train_batch_size = self.train_dataloader.batch_size
        self.total_train_iteration = self.train_data_size // self.train_batch_size
        print("total_train_iteration : ",self.total_train_iteration)
        self.train_step = 0

        # Pick the CPU, GPU (CUDA) or MPS device used for training.
        if torch.cuda.is_available():
            self.device = "cuda"
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        print(f"Using {self.device} device")

        self.model = self._build_model(child_model)
        if self.GPUparallel:
            self.model = DataParallel(self.model)

        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.child_model_lr, weight_decay=0.0001)
        # The metric must match the dataset's class count (was fixed at 100).
        # NOTE: the same metric object is reused for validation batches.
        self.train_F1 = torchmetrics.F1Score(task="multiclass", num_classes=self.classNum).to(self.device)

        self.current_train_score_mean = 0
        self.current_val_score_mean = 0

        self.MaxF1 = 0
        self.today = today

        # Global optimizer-step counter and the period (in steps) at which
        # envSetPeriodicTransferLearning() is triggered in train().
        self.steps = 1
        self.timing = 25000

    def _build_dataloaders(self, dataset, cifar100_root=None):
        """Set the train/test loaders and ``classNum`` for ``dataset``."""
        if dataset == "CIFAR100":
            root = self._DEFAULT_CIFAR100_ROOT if cifar100_root is None else cifar100_root

            self.transformer = transforms.Compose([
                transforms.Resize((240, 240)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

            self.training_data = torchvision.datasets.CIFAR100(
                root=root, train=True,
                transform=self.transformer, target_transform=None, download=True)
            self.test_data = torchvision.datasets.CIFAR100(
                root=root, train=False,
                transform=self.transformer, target_transform=None, download=True)

            self.train_dataloader = DataLoader(self.training_data, batch_size=self.child_model_batch_size, shuffle=True)
            self.test_dataloader = DataLoader(self.test_data, batch_size=self.child_model_batch_size, shuffle=True)

            self.classNum = 100

        elif dataset == "FER2013":
            # The loader factory returns a third value that is stored but
            # not used anywhere in this class.
            self.train_dataloader, self.test_dataloader, self.none = FER2013DataLoader(self.child_model_batch_size)
            self.classNum = 7

        elif dataset == "AffectNet":
            self.train_dataloader, self.test_dataloader = AffectNetDataLoader(self.child_model_batch_size)
            self.classNum = 8

        else:
            raise ValueError(
                f"Unsupported dataset {dataset!r}; expected 'CIFAR100', 'FER2013' or 'AffectNet'"
            )

    def _build_model(self, child_model):
        """Instantiate the requested child network on ``self.device``."""
        factories = {
            "resnet18": resnet18,
            "resnet34": resnet34,
            "resnet50": resnet50,
            "resnet101": resnet101,
            "resnet152": resnet152,
            # NOTE(review): legacy misspelling that used to build a resnet34;
            # kept so existing configs behave the same -- confirm and remove.
            "resent50": resnet34,
        }
        try:
            factory = factories[child_model]
        except KeyError:
            raise ValueError(
                f"Unsupported child_model {child_model!r}; expected one of {sorted(factories)}"
            ) from None
        return factory(self.classNum).to(self.device)

    def _inner_model(self):
        """Return the underlying network, unwrapping ``DataParallel`` if used."""
        return self.model.module if self.GPUparallel else self.model

    def _make_next_state(self, queues):
        """Build the agent state that matches ``self.agent``."""
        if self.agent == "LSTM_agent":
            return self.make_state_for_LSTM_agent(queues, self.QUEUE_ITER_SIZE)
        # "MLP_agent", "EnsembledMLP_agent" and any other name get the flat
        # state, which is what every agent received before the fix.
        return self.make_state_for_MLP_agent(queues)

    def step(self, given_action, queues):
        """Apply ``given_action``, train and validate once, and build the next state.

        Args:
            given_action: Action forwarded to the model's ``envSet``.
            queues: Object holding the four metric queues; they are filled by
                :meth:`train` and :meth:`validation`.

        Returns:
            ``(reward, next_state, done)``. ``reward`` is the mean validation
            F1 minus the mean training F1. ``done`` is 1 once the accumulated
            training steps exceed one pass over the training set, else 0.
        """
        done = 0

        # Injection module setting
        self._inner_model().envSet(given_action, self.ACTION_IDX)

        # Collect training and validation data
        self.current_train_score_mean = self.train(queues.train_loss_queue, queues.train_f1_queue)
        self.current_val_score_mean = self.validation(queues.val_loss_queue, queues.val_f1_queue)

        # Reward is the validation/train F1 gap
        reward = self.current_val_score_mean - self.current_train_score_mean
        wandb.log({"reward": reward})

        # Run a full validation pass once per (approximate) epoch
        self.train_step = self.train_step + self.TRAIN_STEP_PER_ITER
        if self.train_step > self.total_train_iteration:
            self.train_step = 0
            self.validation_all()
            done = 1

        next_state = self._make_next_state(queues)

        return reward, next_state, done

    def train(self, train_loss_queue, train_f1_queue):
        """Train for up to ``TRAIN_STEP_PER_ITER`` batches.

        Per-batch loss and F1 are enqueued (as tensors on ``self.device``)
        and the loss is logged to wandb.

        Returns:
            Mean of the per-batch F1 scores.

        Raises:
            ZeroDivisionError: If no batch was processed.
        """
        self.model.train()

        F1s = []

        for batch, (X, y) in enumerate(self.train_dataloader):
            if batch == self.TRAIN_STEP_PER_ITER:
                break
            X, y = X.to(self.device), y.to(self.device)
            pred = self.model(X)
            loss = self.loss_fn(pred, y)
            F1 = self.train_F1(pred, y)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            train_loss_queue.enqueue(torch.tensor(loss.item()).to(self.device))
            train_f1_queue.enqueue(torch.tensor(F1.item()).to(self.device))

            wandb.log({"train loss": loss.item()})

            F1s.append(F1.item())

            # Periodic injection. Go through the unwrap helper so this also
            # works when the model is not wrapped in DataParallel.
            if self.steps % self.timing == 0:
                self._inner_model().envSetPeriodicTransferLearning()

            self.steps = self.steps + 1

        meanF1 = sum(F1s) / len(F1s)
        return meanF1

    def validation(self, val_loss_queue, val_f1_queue):
        """Evaluate for up to ``VAL_STEP_PER_ITER`` batches without gradients.

        Per-batch loss and F1 are enqueued as Python floats and logged.

        Returns:
            Mean of the per-batch F1 scores.

        Raises:
            ZeroDivisionError: If no batch was processed.
        """
        self.model.eval()
        with torch.no_grad():
            F1s = []
            losses = []
            for batch, (X, y) in enumerate(self.test_dataloader):
                if batch == self.VAL_STEP_PER_ITER:
                    break
                X, y = X.to(self.device), y.to(self.device)

                pred = self.model(X)
                loss_value = self.loss_fn(pred, y).item()
                f1_value = self.train_F1(pred, y).item()

                val_loss_queue.enqueue(loss_value)
                val_f1_queue.enqueue(f1_value)

                losses.append(loss_value)
                F1s.append(f1_value)

                # Log a plain float, consistent with the other log calls.
                wandb.log({"instant test loss": loss_value})

            meanF1 = sum(F1s) / len(F1s)
            wandb.log({"instant  test f1": meanF1})
            return meanF1

    def validation_all(self):
        """Evaluate on the whole test loader and checkpoint on a new best F1.

        The model's ``state_dict`` is saved under ``checkpointFolderPath``
        whenever the mean per-batch F1 exceeds the best value seen so far.
        """
        self.model.eval()
        with torch.no_grad():
            F1s = []
            for X, y in self.test_dataloader:
                X, y = X.to(self.device), y.to(self.device)
                pred = self.model(X)
                loss = self.loss_fn(pred, y).item()
                F1 = self.train_F1(pred, y)
                F1s.append(F1.item())
                wandb.log({"test loss": loss})
            F1 = sum(F1s) / len(F1s)
            wandb.log({"test f1": F1})

            if F1 > self.MaxF1:
                self.MaxF1 = F1

                torch.save(self.model.state_dict(),f"{self.checkpointFolderPath}/{self.child_model_batch_size}+{self.child_model_lr}+{self.child_model}+{self.dataset}+{self.today}+{self.type}.pth" )


