from loss_functions import *
from utils import *
from data_utils import sample
import time
from tqdm import tqdm
from os import path
import numpy as np

class EncoderTrainer:
    """Training loops for an encoder, optionally paired with a decoder or an
    action predictor.

    Each ``train_*`` method runs ``args.num_epochs`` epochs, prints per-epoch
    averages, overwrites ``model_recent.tar`` in ``args.checkpoint_dir`` after
    every epoch and writes ``model_final.tar`` when it finishes. Only the
    encoder's ``state_dict`` is saved.
    """

    def __init__(
            self,
            args,
            enc,
            train_dataloader,
            opt,
            schedular,
            dec=None,
            action_predictor=None
            ):
        """Store the training components.

        Args:
            args: Configuration namespace. The attributes read depend on the
                training method (e.g. ``num_epochs``, ``checkpoint_dir``,
                ``code_size``, ``sample_training_data``).
            enc: Encoder being trained.
            train_dataloader: Iterable of training batches.
            opt: Optimizer; ``zero_grad``/``step`` are called every step.
            schedular: LR scheduler; only ``train_autoencoder`` steps it.
            dec: Decoder, required by ``train_autoencoder``.
            action_predictor: Passed to ``loss_fn_ed_inverse_action`` when
                ``args.use_action_pred`` is set.
        """
        self.args = args
        self.enc = enc
        self.train_dataloader = train_dataloader
        self.opt = opt
        self.schedular = schedular
        self.device = get_device(enc)
        # Zero vector of length args.code_size on the encoder's device; it is
        # handed to the symmetry losses as the origin of rotation.
        self.origin_of_rotation = torch.zeros(
            args.code_size, dtype=torch.float32, device=self.device
        )
        self.dec = dec
        self.action_predictor = action_predictor

    def _next_example(self, data):
        """Return ``(x, actions)`` for one training step.

        With ``args.sample_training_data`` the dataloader item is ignored and
        one example is drawn via ``sample``. Both parts are wrapped in a
        one-element list so that index 0 selects the whole sampled example,
        exactly like a batch of size 1. Otherwise ``data[0]`` and ``data[1]``
        are returned unchanged.
        """
        if self.args.sample_training_data:
            x, actions = sample(
                self.args.data_dir,
                self.args.train_cluster_size,
                self.args.train_num_actions
            )
            return [x], [actions]
        return data[0], data[1]

    def _save_checkpoint(self, filename):
        """Save the encoder's ``state_dict`` to ``args.checkpoint_dir/filename``.

        Returns:
            The path that was written.
        """
        save_path = path.join(self.args.checkpoint_dir, filename)
        print('saving model to %s' % save_path)
        torch.save(self.enc.state_dict(), save_path)
        return save_path

    def train_symmetryreg(self):
        """Train the encoder with the symmetry-regularisation objective.

        For every batch, the per-example losses from
        ``loss_fn_ed_inverse_action`` (when ``args.use_action_pred``) or
        ``loss_fn_ed`` are averaged over the batch, and
        ``equiv + extra + barrier`` is minimised. The ``extra`` term is the
        third output of ``loss_fn_ed_inverse_action`` and zero otherwise.
        Logged values are multiplied by 1e4 for readability. The LR scheduler
        is not stepped.
        """
        avg_loss_list = []
        avg_loss_equiv_list = []
        avg_loss_barrier_list = []
        avg_loss_ortho_list = []
        for t in range(self.args.num_epochs):
            self.enc.train()
            loss_list = []
            loss_equiv_list = []
            loss_barrier_list = []
            loss_ortho_list = []

            time_start = time.time()
            progress = tqdm(
                enumerate(self.train_dataloader),
                desc='Loss: None | Loss Equiv: None | Loss barrier: None | Loss extra: None | L2 Grads: None',
                total=len(self.train_dataloader), position=0, leave=True
            )
            for _, data in progress:
                x, actions = self._next_example(data)
                # Equals x.shape[0] for a batched tensor, and 1 for a sample.
                batch_size = len(x)

                loss_equiv = 0.
                loss_barrier = 0.
                loss_extra = 0.
                # NOTE: earlier revisions experimented with other objectives
                # here (loss_fn, loss_fn_cosine_so3,
                # loss_fn_cosine_so3_inverse_action).
                for i in range(batch_size):
                    if self.args.use_action_pred:
                        loss_eq, loss_bar, loss_or = loss_fn_ed_inverse_action(
                            self.enc, self.action_predictor, x[i], actions[i],
                            self.origin_of_rotation, self.args.barrier_type,
                            self.args.hinge_thresh, self.args.cosine_sim,
                            self.args.conformal_map, self.args.rotation_map,
                            self.args.decompositions
                        )
                    else:
                        loss_eq, loss_bar = loss_fn_ed(
                            self.enc, x[i], self.args.barrier_type,
                            self.args.hinge_thresh, self.args.cosine_sim, self.args.conformal_map,
                            self.args.rotation_map, self.args.decompositions,
                            self.origin_of_rotation
                        )
                        # loss_fn_ed has no third term; contribute zero.
                        loss_or = torch.zeros(1, device=self.device)
                    loss_equiv += loss_eq / batch_size
                    loss_barrier += loss_bar / batch_size
                    loss_extra += loss_or / batch_size
                loss = loss_equiv + loss_extra + loss_barrier
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                # Losses are multiplied by 1e4 to get readable numbers.
                loss_list.append(1e4 * loss.item())
                loss_equiv_list.append(1e4 * loss_equiv.item())
                loss_barrier_list.append(1e4 * loss_barrier.item())
                loss_ortho_list.append(1e4 * loss_extra.item())

                progress.set_description(
                    'Loss: %12g | Loss Equiv: %12g | Loss barrier: %12g | Loss extra: %12g | L2 Grads: %12g' % (
                    loss_list[-1],
                    loss_equiv_list[-1],
                    loss_barrier_list[-1],
                    loss_ortho_list[-1],
                    get_grads_norm(self.enc.parameters(), norm_type=2.0)
                ))
            time_end = time.time()

            avg_loss = np.mean(loss_list)
            avg_loss_equiv = np.mean(loss_equiv_list)
            avg_loss_barrier = np.mean(loss_barrier_list)
            avg_loss_ortho = np.mean(loss_ortho_list)
            avg_loss_list.append(avg_loss)
            avg_loss_equiv_list.append(avg_loss_equiv)
            avg_loss_barrier_list.append(avg_loss_barrier)
            avg_loss_ortho_list.append(avg_loss_ortho)
            print('\nEpoch %3d | Loss: %12g | Loss Equiv: %12g | Loss barrier: %12g | Loss extra: %12g | Time: %6.1f sec' % (
                t + 1, avg_loss, avg_loss_equiv, avg_loss_barrier, avg_loss_ortho, time_end - time_start))
            self._save_checkpoint('model_recent.tar')
            # self.schedular.step()  # LR schedule is disabled for this objective.
            print('Learning rate:', self.opt.param_groups[0]["lr"])
        self._save_checkpoint('model_final.tar')

    def train_autoencoder(self):
        """Train with a reconstruction objective through ``self.dec``.

        Each step takes ``x[0]`` of the batch, reshapes it to
        ``(-1, s[2], s[3], s[4])`` where ``s`` is its shape, and minimises the
        MSE between ``tanh(dec(enc(x)))`` and ``x / 255``. The LR scheduler is
        stepped once per epoch. Logged losses are multiplied by 1e4.
        """
        mse = torch.nn.MSELoss()
        avg_loss_list = []
        for t in range(self.args.num_epochs):
            self.enc.train()
            # Keep the decoder's mode (dropout / batch-norm) consistent with
            # the encoder's during training.
            if isinstance(self.dec, torch.nn.Module):
                self.dec.train()
            loss_list = []

            time_start = time.time()
            progress = tqdm(enumerate(self.train_dataloader),
                 desc='Loss: None | L2 Weights: None | L2 Grads: None',
                 total=self.args.steps_per_epoch, position=0, leave=True)
            for _, data in progress:
                x, _ = self._next_example(data)
                if torch.is_tensor(x):
                    x = x.to(self.device)
                else:
                    x = torch.Tensor(np.array(x)).to(self.device)
                # NOTE(review): assumes x[0] is 5-D so the first two dims get
                # merged into one batch dim — confirm against the data layout.
                x = x[0]
                x = x.reshape(-1, x.shape[2], x.shape[3], x.shape[4])
                x_rec = self.dec(self.enc(x))
                # NOTE(review): the target x / 255 presumably lies in [0, 1]
                # while tanh spans [-1, 1] — confirm this is intended.
                loss = mse(torch.tanh(x_rec), x / 255.0)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                # Losses are multiplied by 1e4 to get readable numbers.
                loss_list.append(1e4 * loss.item())
                progress.set_description(
                    'Loss: %12g | L2 Weights: %12g | L2 Grads: %12g' % (
                    loss_list[-1],
                    get_weights_norm(self.enc.parameters(), norm_type=2.0),
                    get_grads_norm(self.enc.parameters(), norm_type=2.0)
                ))
            time_end = time.time()

            avg_loss = np.mean(loss_list)
            avg_loss_list.append(avg_loss)
            print('\nEpoch %3d | Loss: %12g | Time: %6.1f sec' % (
                t + 1, avg_loss, time_end - time_start))
            # NOTE(review): only the encoder is checkpointed; the decoder's
            # weights are not saved — confirm this is intended.
            self._save_checkpoint('model_recent.tar')
            self.schedular.step()
        self._save_checkpoint('model_final.tar')

    def train_DRLIM(self):
        """Train the encoder with the contrastive objective ``loss_contrastive``.

        Each batch ``(x_1, x_2)`` is concatenated along dim 0, and the sum of
        the invariance and barrier terms returned by ``loss_contrastive`` is
        minimised. Losses are logged unscaled. Autograd anomaly detection is
        on while training and is restored to its previous setting afterwards.
        """
        avg_loss_list = []
        avg_loss_inv_list = []
        avg_loss_barrier_list = []
        # Anomaly detection is a slow debugging aid; scope it to this method
        # instead of leaving it switched on for the rest of the process.
        with torch.autograd.set_detect_anomaly(True):
            for t in range(self.args.num_epochs):
                self.enc.train()
                loss_list = []
                loss_inv_list = []
                loss_barrier_list = []

                time_start = time.time()
                progress = tqdm(enumerate(self.train_dataloader),
                     desc='Loss: None | Loss Inv: None | Loss barrier: None'
                     , total=len(self.train_dataloader), position=0, leave=True
                )
                for _, data in progress:
                    x_1, x_2 = data
                    x = torch.cat([x_1, x_2], dim=0)
                    self.opt.zero_grad()
                    loss_inv, loss_barrier = loss_contrastive(
                        self.enc, x, self.args.hinge_thresh,
                        self.args.temperature, self.args.cosine_sim
                    )
                    loss = loss_inv + loss_barrier
                    loss.backward()
                    self.opt.step()
                    # Unlike the other objectives, these losses are not rescaled.
                    loss_list.append(loss.item())
                    loss_inv_list.append(loss_inv.item())
                    loss_barrier_list.append(loss_barrier.item())

                    progress.set_description(
                        'Loss: %12g | Loss Inv: %12g | Loss barrier: %12g' % (
                        loss_list[-1],
                        loss_inv_list[-1],
                        loss_barrier_list[-1]
                    ))
                time_end = time.time()

                avg_loss = np.mean(loss_list)
                avg_loss_inv = np.mean(loss_inv_list)
                avg_loss_barrier = np.mean(loss_barrier_list)
                avg_loss_list.append(avg_loss)
                avg_loss_inv_list.append(avg_loss_inv)
                avg_loss_barrier_list.append(avg_loss_barrier)
                print('\nEpoch %3d | Loss: %12g | Loss Inv: %12g | Loss barrier: %12g | Time: %6.1f sec' % (
                    t + 1, avg_loss, avg_loss_inv, avg_loss_barrier, time_end - time_start))
                self._save_checkpoint('model_recent.tar')
                # self.schedular.step()  # LR schedule is disabled for this objective.
                print('Learning rate:', self.opt.param_groups[0]["lr"])
        self._save_checkpoint('model_final.tar')
