import os
import torch
from torch import nn
from torch.nn import functional as F
from utils import utils
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
import logging
import pdb 
import sklearn.metrics

from utils.datautils import *
from torch.utils.tensorboard import SummaryWriter

class USSLEval(nn.Module):
    """Fit and score a classifier head on features from a frozen encoder.

    The encoder weights are loaded from ``args.ckpt`` and frozen; only
    ``self.classifier`` is optimised. The head is chosen from ``args``:
    convolutional (``use_conv_classifier``), MLP (``use_mlp_classifier``) or,
    by default, a zero-initialised linear layer.

    Args:
        simclr: model exposing ``encoder`` (a wrapper with a ``.module`` that
            has ``encoder_dim``); calling ``encoder(images)`` must return a
            pair whose first element is the feature tensor.
        args: configuration namespace; the attributes read are visible in the
            method bodies (``ckpt``, ``rank``, ``gpu``, ``dist``, ``task``,
            ``iters``, ``log_freq``, classifier/optimizer settings, ...).
    """

    def __init__(self, simclr, args):
        super().__init__()
        self.args = args
        self.simclr = simclr
        self.writer = SummaryWriter(log_dir=self.args.root)

        self._load_encoder_weights()

        h_dim = self.simclr.encoder.module.encoder_dim
        n_classes = get_classes(dataset=self.args.task)
        self.classifier = self.set_parallel_device(
            self._build_classifier(h_dim, n_classes)
        )

        self.cur_iter = 0
        self.best_train_acc = 0
        self.best_test_acc = 0

        self.set_loaders()

        self.optimizer, self.scheduler = utils.get_optimizer_scheduler(
            opt=self.args.classifier_opt,
            LR=self.args.classifier_lr,
            weight_decay=self.args.classifier_weight_decay,
            lr_schedule=self.args.classifier_lr_schedule,
            warmup=self.args.classifier_warmup,
            T_max=self.args.iters,
            model=self.classifier,
            last_epoch=self.cur_iter - 1,
        )

    def _load_encoder_weights(self):
        """Load ``args.ckpt['encoder']`` into the encoder and freeze ``simclr``."""
        logging.info(f"Loading checkpoint from {self.args.ckpt}")
        print(f"Loading checkpoint from {self.args.ckpt}")

        # Remap tensors saved on cuda:0 onto this process's device.
        map_location = {'cuda:%d' % 0: 'cuda:%d' % self.args.rank}
        # NOTE: torch.load unpickles the file, so the checkpoint must be trusted.
        ckpt = torch.load(self.args.ckpt, map_location=map_location)

        # Drop the 'module.' prefix so keys match the unwrapped encoder.
        encoder_state = {
            key.replace('module.', ''): value
            for key, value in ckpt['encoder'].items()
        }
        self.simclr.encoder.module.load_state_dict(encoder_state)
        self.simclr.encoder.eval()

        for p in self.simclr.parameters():
            p.requires_grad = False

    def _build_classifier(self, h_dim, n_classes):
        """Return the un-wrapped classifier head selected by ``args``.

        Precedence is conv, then MLP, then linear, so exactly one head is
        built and the MLP option is not overwritten by the linear default.
        """
        if self.args.use_conv_classifier:
            ndf = self.args.ndf
            # Expects (N, h_dim // 4, 2, 2) input; spatial size goes
            # 2 -> 5 -> 10 -> 18, hence the 18 * 18 in the final Linear.
            return nn.Sequential(
                nn.ConvTranspose2d(h_dim // 4, ndf * 8, 3, 2),
                nn.ReLU(True),
                nn.ConvTranspose2d(ndf * 8, ndf * 4, 2, 2),
                nn.ReLU(True),
                nn.ConvTranspose2d(ndf * 4, ndf * 2, 2, 2, 1),
                nn.ReLU(True),
                nn.Flatten(),
                nn.Linear(18 * 18 * ndf * 2, n_classes),
            )

        if self.args.use_mlp_classifier:
            return nn.Sequential(
                nn.Linear(h_dim, h_dim * 2),
                nn.ReLU(),
                nn.Linear(h_dim * 2, h_dim),
                nn.ReLU(),
                nn.Linear(h_dim, n_classes),
            )

        linear = nn.Linear(h_dim, n_classes)
        linear.weight.data.zero_()
        linear.bias.data.zero_()
        return linear

    def set_parallel_device(self, component):
        """Move ``component`` to ``args.gpu`` and wrap it for parallel use.

        With ``args.dist == 'ddp'`` the module is (optionally) converted to
        SyncBatchNorm and wrapped in DistributedDataParallel after a barrier;
        otherwise it is wrapped in ``nn.DataParallel``.
        """
        component = component.cuda(self.args.gpu)

        if self.args.dist == 'ddp':
            # Imported here so the name does not depend on a star import.
            import torch.distributed as dist

            if self.args.sync_bn:
                component = nn.SyncBatchNorm.convert_sync_batchnorm(component)
            dist.barrier()
            component = DDP(component, [self.args.gpu], find_unused_parameters=True)
        else:
            component = nn.DataParallel(component)

        return component

    def _encode(self, images):
        """Return encoder features for ``images``, shaped for the classifier.

        Shared by ``step`` and ``test`` so both apply the same input scaling
        (ImageNet inputs are divided by 255) and the same reshape for the
        convolutional head.
        """
        if self.args.task == 'imagenet':
            images = images / 255.

        h, _ = self.simclr.encoder(images)
        if self.args.use_conv_classifier:
            h = h.view(h.shape[0], self.simclr.encoder.module.encoder_dim // 4, 2, 2)
        return h

    def step(self, images, labels):
        """Run one forward pass and return ``{'eval_loss': cross-entropy}``."""
        p = self.classifier(self._encode(images))
        loss = F.cross_entropy(p, labels)
        return {
            'eval_loss': loss
        }

    def set_loaders(self):
        """Create the train/test samplers and loaders from ``args``.

        Sets ``trainsampler``, ``testsampler``, ``train_loader`` and
        ``test_loader``. The ``get_*`` helpers are not defined in this file
        (presumably provided by the ``utils.datautils`` star import).
        """
        train_transform, test_transform = get_linear_model_transforms(
            dataset_name=self.args.task,
            image_size=self.args.image_size,
            color_dist_s=self.args.color_dist_s,
            scale_lower=self.args.scale_lower,
        )

        train_set, test_set = get_dataset(
            dataset_name=self.args.task,
            data_root=self.args.data_root,
            train_transform=train_transform,
            test_transform=test_transform,
        )

        self.trainsampler = get_sampler(dataset=train_set, dist=self.args.dist)
        self.testsampler = get_sampler(dataset=test_set, dist=self.args.dist)

        self.train_loader = get_dataloader(
            dataset=train_set,
            sampler=self.trainsampler,
            batch_size=self.args.batch_size,
            workers=self.args.workers,
        )
        self.test_loader = get_dataloader(
            dataset=test_set,
            sampler=self.testsampler,
            batch_size=self.args.batch_size,
            workers=self.args.workers,
        )

    def test(self, loader):
        """Return top-1 accuracy of the classifier over ``loader``.

        The classifier is put in eval mode for the pass and switched back to
        train mode afterwards, even if evaluation raises. Returns NaN when the
        loader yields no samples. Under DDP the value covers only the samples
        this process iterates; nothing is reduced across ranks here.
        """
        self.classifier.eval()
        n_correct = 0
        n_seen = 0
        try:
            # No gradients are needed for scoring.
            with torch.no_grad():
                for images, labels in loader:
                    images = images.cuda(self.args.gpu)
                    logits = self.classifier(self._encode(images))

                    preds = logits.argmax(1).cpu()
                    targets = torch.as_tensor(labels).reshape(-1).cpu()
                    n_correct += (preds == targets).sum().item()
                    n_seen += targets.numel()
        finally:
            self.classifier.train()

        return n_correct / n_seen if n_seen else float('nan')

    def train(self, mode=True):
        """Optimise the classifier until ``args.iters`` iterations are done.

        NOTE: this overrides ``nn.Module.train``. Because ``nn.Module.eval()``
        calls ``self.train(False)``, a falsy ``mode`` is forwarded to the base
        implementation (switching to eval mode and returning ``self``) instead
        of starting the optimisation loop.

        Every ``args.log_freq`` iterations (and on the final one) the train
        accuracy is measured and, on rank 0, written to TensorBoard and the log.

        Raises:
            RuntimeError: if ``train_loader`` yields no batches, since the
                target iteration count could then never be reached.
        """
        if not mode:
            return super().train(False)

        while self.cur_iter < self.args.iters:
            train_logs = []

            # set_epoch only affects the next iterator, so call it once per pass.
            if self.args.dist == 'ddp':
                self.trainsampler.set_epoch(self.cur_iter)

            n_batches = 0
            for images, labels in self.train_loader:
                n_batches += 1
                images = images.cuda(self.args.gpu)
                labels = labels.cuda(self.args.gpu)

                self.cur_iter += 1

                logs = self.step(images, labels)
                loss = logs['eval_loss']

                # gradient step
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # save logs for the batch
                train_logs.append({k: utils.tonp(v) for k, v in logs.items()})

                if self.scheduler is not None:
                    self.scheduler.step()

                is_last_iter = self.cur_iter == self.args.iters
                if self.cur_iter % self.args.log_freq == 0 or is_last_iter:
                    train_logs = utils.agg_all_metrics(train_logs)

                    train_acc = self.test(self.train_loader)
                    # NOTE: evaluation on self.test_loader is currently
                    # disabled, so test_acc is always reported as 0.
                    test_acc = 0

                    self.best_train_acc = max(self.best_train_acc, train_acc)
                    self.best_test_acc = max(self.best_test_acc, test_acc)

                    if self.args.rank == 0:
                        print(f"Train_acc: {train_acc}, Best_train_acc: {self.best_train_acc}")
                        print(f"Test_acc: {test_acc}, Best_test_acc: {self.best_test_acc}")

                        train_logs['train_acc'] = train_acc
                        train_logs['best_train_acc'] = self.best_train_acc
                        train_logs['test_acc'] = test_acc
                        train_logs['best_test_acc'] = self.best_test_acc

                        self.writer.add_scalar(
                            'eval_learning_rate',
                            self.optimizer.param_groups[0]['lr'],
                            global_step=self.cur_iter,
                        )
                        # The accuracy entries added above are written here
                        # together with the aggregated batch metrics.
                        for key in train_logs:
                            self.writer.add_scalar(key, train_logs[key], global_step=self.cur_iter)

                        logging.info(f"Iter: {self.cur_iter}/{self.args.iters}, {train_logs}")

                    train_logs = []

                if is_last_iter:
                    break

            if n_batches == 0:
                raise RuntimeError(
                    "train_loader yielded no batches; cannot reach args.iters"
                )