import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import copy
from tqdm import tqdm
import os
import numpy as np
from .utils import cnn, IQLDataset, DEFAULT_DEVICE


class BehaviorCloningModel(nn.Module):
    """Behavior-cloning wrapper around the project's ``cnn`` network.

    The network is built lazily by :meth:`set_init`, which either restores
    weights from ``load_path`` or trains the network with a cross-entropy
    loss on the supplied dataset and then saves the weights to ``load_path``.

    Args:
        num_epochs: Multiplied by ``batch_size`` and stored as
            ``self.num_epochs`` (see the note in ``__init__``).
        batch_size: Mini-batch size used by the training ``DataLoader``.
        load_path: File path of the checkpoint to load from / save to.
    """

    def __init__(self, num_epochs, batch_size, load_path):
        super().__init__()
        # NOTE: despite its name this attribute holds
        # ``num_epochs * batch_size``. It is passed as the second argument to
        # ``IQLDataset`` in ``update`` -- presumably the number of samples to
        # draw; confirm against ``IQLDataset``.
        self.num_epochs = num_epochs * batch_size
        self.batch_size = batch_size
        self.load_path = load_path

    def set_init(self, param_dict, learning_rate=3e-4):
        """Build the network, then load its weights or train and save them.

        Args:
            param_dict: Mapping with an ``'env'`` entry exposing
                ``in_channels`` and ``num_actions_``, and (only read when no
                checkpoint exists) a ``'dataset'`` entry that is forwarded to
                :meth:`update`.
            learning_rate: Learning rate for the Adam optimizer.
        """
        in_channels = param_dict['env'].in_channels
        num_actions = param_dict['env'].num_actions_

        self.bc = cnn(in_channels, num_actions).to(DEFAULT_DEVICE)

        # Created unconditionally so that ``update`` remains callable after
        # the weights were restored from a checkpoint.
        self.optimizer = optim.Adam(self.bc.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()

        if os.path.exists(self.load_path):
            # ``map_location`` lets a checkpoint saved on one device (e.g. a
            # GPU) be restored on whatever device this process uses.
            checkpoint = torch.load(self.load_path, map_location=DEFAULT_DEVICE)
            self.bc.load_state_dict(checkpoint)
            return

        # Make sure the target directory exists *before* training so a
        # missing directory does not discard the trained weights at save time.
        save_dir = os.path.dirname(self.load_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        self.update(param_dict['dataset'])
        torch.save(self.bc.state_dict(), self.load_path)

    def update(self, dataset):
        """Run one pass of supervised training over ``dataset``.

        The dataset is wrapped in ``IQLDataset`` and iterated in shuffled
        mini-batches; each batch must provide ``'states'`` and ``'actions'``
        entries.

        Args:
            dataset: Raw dataset object accepted by ``IQLDataset``.
        """
        iqldataset = IQLDataset(dataset, self.num_epochs)
        # NOTE(review): pinning is enabled only when CUDA is *not* available,
        # the reverse of the usual setup. Presumably ``IQLDataset`` already
        # yields tensors on DEFAULT_DEVICE (batches are never moved below),
        # in which case pinning them would fail -- confirm before changing.
        pin_memory = not torch.cuda.is_available()
        data_loader = DataLoader(
            iqldataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=pin_memory,
        )

        # Switching to training mode once is enough; nothing in the loop
        # changes the mode.
        self.bc.train()
        for batch in tqdm(data_loader):
            obs = batch['states']
            action = batch['actions']

            self.optimizer.zero_grad()
            outputs = self.bc(obs)

            # The two ``squeeze(1)`` calls drop two size-1 axes after the
            # batch axis, leaving the 1-D class-index target that
            # ``CrossEntropyLoss`` expects.
            target = action.long().squeeze(1).squeeze(1)
            loss = self.criterion(outputs, target)
            loss.backward()
            self.optimizer.step()

    def train(self, train_state=True, all_state=None):
        """Build an ``il_agent``, or toggle training mode like ``nn.Module``.

        This method shares its name with ``nn.Module.train``, which
        ``nn.Module.eval`` calls internally as ``self.train(False)``. To keep
        those standard calls working, the call is forwarded to
        ``nn.Module.train`` whenever ``all_state`` is omitted.

        Args:
            train_state: Iterable of state rows used for training. When
                ``all_state`` is omitted this is instead interpreted as the
                boolean ``mode`` of ``nn.Module.train``.
            all_state: Iterable of all state rows. Rows whose tuple form does
                not occur in ``train_state`` are handed to the agent.

        Returns:
            An ``il_agent`` wrapping ``self.bc`` and the rows of ``all_state``
            absent from ``train_state``; or ``self`` when used as the
            ``nn.Module.train`` mode switch.
        """
        if all_state is None:
            return super().train(train_state)

        train_set = set(map(tuple, train_state))
        diff = np.array([row for row in all_state if tuple(row) not in train_set])
        return il_agent(self.bc, diff)
    

class il_agent:
    """Agent that samples actions from a network's softmax output.

    Args:
        bc: Network to query; it is switched to evaluation mode on
            construction.
        seen_state: Stored unchanged on the instance as ``seen_state``.
    """

    def __init__(self, bc, seen_state):
        self.bc = bc
        self.seen_state = seen_state
        self.Q = None
        # Inference only: put the wrapped network into evaluation mode.
        bc.eval()

    def policy(self, s):
        """Sample one action index per row of ``s``.

        Args:
            s: NumPy array of observations; it is converted to a tensor and
                moved to ``DEFAULT_DEVICE`` before being fed to the network.

        Returns:
            NumPy array holding one sampled index per row, drawn from the
            softmax (over dimension 1) of the network output.
        """
        with torch.no_grad():
            state = torch.from_numpy(s).to(DEFAULT_DEVICE)
            logits = self.bc(state)
            probabilities = torch.softmax(logits, dim=1)
            sampled = torch.multinomial(probabilities, num_samples=1)
            return sampled.detach().cpu().numpy()
    
